diff --git a/PPOCRLabel/PPOCRLabel.py b/PPOCRLabel/PPOCRLabel.py index c17db91a5b5cd9d3cbb4b5bf6c87afd745d0870d..0a3ae1cb3b8fc004aa7c48dc86b6546a80e17a0f 100644 --- a/PPOCRLabel/PPOCRLabel.py +++ b/PPOCRLabel/PPOCRLabel.py @@ -2449,13 +2449,6 @@ class MainWindow(QMainWindow): export PPLabel and CSV to JSON (PubTabNet) ''' import pandas as pd - from libs.dataPartitionDialog import DataPartitionDialog - - # data partition user input - partitionDialog = DataPartitionDialog(parent=self) - partitionDialog.exec() - if partitionDialog.getStatus() == False: - return # automatically save annotations self.saveFilestate() @@ -2478,28 +2471,19 @@ class MainWindow(QMainWindow): labeldict[file] = eval(label) else: labeldict[file] = [] + + # read table recognition output + TableRec_excel_dir = os.path.join( + self.lastOpenDir, 'tableRec_excel_output') - train_split, val_split, test_split = partitionDialog.getDataPartition() - # check validate - if train_split + val_split + test_split > 100: - msg = "The sum of training, validation and testing data should be less than 100%" - QMessageBox.information(self, "Information", msg) - return - print(train_split, val_split, test_split) - train_split, val_split, test_split = float(train_split) / 100., float(val_split) / 100., float(test_split) / 100. - train_id = int(len(labeldict) * train_split) - val_id = int(len(labeldict) * (train_split + val_split)) - print('Data partition: train:', train_id, - 'validation:', val_id - train_id, - 'test:', len(labeldict) - val_id) - - TableRec_excel_dir = os.path.join(self.lastOpenDir, 'tableRec_excel_output') - json_results = [] - imgid = 0 + # save txt + fid = open( + "{}/gt.txt".format(self.lastOpenDir), "w", encoding='utf-8') for image_path in labeldict.keys(): # load csv annotations filename, _ = os.path.splitext(os.path.basename(image_path)) - csv_path = os.path.join(TableRec_excel_dir, filename + '.xlsx') + csv_path = os.path.join( + TableRec_excel_dir, filename + '.xlsx') if not os.path.exists(csv_path): continue @@ -2518,28 +2502,31 @@ class MainWindow(QMainWindow): cells = [] for anno in labeldict[image_path]: tokens = list(anno['transcription']) - obb = anno['points'] - hbb = OBB2HBB(np.array(obb)).tolist() - cells.append({'tokens': tokens, 'bbox': hbb}) - - # data split - if imgid < train_id: - split = 'train' - elif imgid < val_id: - split = 'val' - else: - split = 'test' - - # save dict - html = {'structure': {'tokens': token_list}, 'cell': cells} - json_results.append({'filename': os.path.basename(image_path), 'split': split, 'imgid': imgid, 'html': html}) - imgid += 1 - - # save json - with open("{}/annotation.json".format(self.lastOpenDir), "w", encoding='utf-8') as fid: - fid.write(json.dumps(json_results, ensure_ascii=False)) - - msg = 'JSON sucessfully saved in {}/annotation.json'.format(self.lastOpenDir) + cells.append({ + 'tokens': tokens, + 'bbox': anno['points'] + }) + + # 构造标注信息 + html = { + 'structure': { + 'tokens': token_list + }, + 'cells': cells + } + d = { + 'filename': os.path.basename(image_path), + 'html': html + } + # 重构HTML + d['gt'] = rebuild_html_from_ppstructure_label(d) + fid.write('{}\n'.format( + json.dumps( + d, ensure_ascii=False))) + + # convert to PP-Structure label format + fid.close() + msg = 'JSON sucessfully saved in {}/gt.txt'.format(self.lastOpenDir) QMessageBox.information(self, "Information", msg) def autolcm(self): @@ -2728,6 +2715,9 @@ class MainWindow(QMainWindow): self._update_shape_color(shape) self.keyDialog.addLabelHistory(key_text) + + # save changed shape + 
self.setDirty() def undoShapeEdit(self): self.canvas.restoreShape() diff --git a/PPOCRLabel/libs/canvas.py b/PPOCRLabel/libs/canvas.py index 81f37995126140b03650f5ddea37ea282d5ceb09..44d899cbc9f21793f89c498cf844c95e418b08a1 100644 --- a/PPOCRLabel/libs/canvas.py +++ b/PPOCRLabel/libs/canvas.py @@ -611,8 +611,8 @@ class Canvas(QWidget): if self.drawing() and not self.prevPoint.isNull() and not self.outOfPixmap(self.prevPoint): p.setPen(QColor(0, 0, 0)) - p.drawLine(self.prevPoint.x(), 0, self.prevPoint.x(), self.pixmap.height()) - p.drawLine(0, self.prevPoint.y(), self.pixmap.width(), self.prevPoint.y()) + p.drawLine(int(self.prevPoint.x()), 0, int(self.prevPoint.x()), self.pixmap.height()) + p.drawLine(0, int(self.prevPoint.y()), self.pixmap.width(), int(self.prevPoint.y())) self.setAutoFillBackground(True) if self.verified: @@ -909,4 +909,4 @@ class Canvas(QWidget): def updateShapeIndex(self): for i in range(len(self.shapes)): self.shapes[i].idx = i - self.update() \ No newline at end of file + self.update() diff --git a/PPOCRLabel/libs/dataPartitionDialog.py b/PPOCRLabel/libs/dataPartitionDialog.py deleted file mode 100644 index 33bd491552fe773bd07020d82f7ea9bab76e7557..0000000000000000000000000000000000000000 --- a/PPOCRLabel/libs/dataPartitionDialog.py +++ /dev/null @@ -1,113 +0,0 @@ -try: - from PyQt5.QtGui import * - from PyQt5.QtCore import * - from PyQt5.QtWidgets import * -except ImportError: - from PyQt4.QtGui import * - from PyQt4.QtCore import * - -from libs.utils import newIcon - -import time -import datetime -import json -import cv2 -import numpy as np - - -BB = QDialogButtonBox - -class DataPartitionDialog(QDialog): - def __init__(self, parent=None): - super().__init__() - self.parnet = parent - self.title = 'DATA PARTITION' - - self.train_ratio = 70 - self.val_ratio = 15 - self.test_ratio = 15 - - self.initUI() - - def initUI(self): - self.setWindowTitle(self.title) - self.setWindowModality(Qt.ApplicationModal) - - self.flag_accept = True - - if self.parnet.lang == 'ch': - msg = "导出JSON前请保存所有图像的标注且关闭EXCEL!" - else: - msg = "Please save all the annotations and close the EXCEL before exporting JSON!" 
- - info_msg = QLabel(msg, self) - info_msg.setWordWrap(True) - info_msg.setStyleSheet("color: red") - info_msg.setFont(QFont('Arial', 12)) - - train_lbl = QLabel('Train split: ', self) - train_lbl.setFont(QFont('Arial', 15)) - val_lbl = QLabel('Valid split: ', self) - val_lbl.setFont(QFont('Arial', 15)) - test_lbl = QLabel('Test split: ', self) - test_lbl.setFont(QFont('Arial', 15)) - - self.train_input = QLineEdit(self) - self.train_input.setFont(QFont('Arial', 15)) - self.val_input = QLineEdit(self) - self.val_input.setFont(QFont('Arial', 15)) - self.test_input = QLineEdit(self) - self.test_input.setFont(QFont('Arial', 15)) - - self.train_input.setText(str(self.train_ratio)) - self.val_input.setText(str(self.val_ratio)) - self.test_input.setText(str(self.test_ratio)) - - validator = QIntValidator(0, 100) - self.train_input.setValidator(validator) - self.val_input.setValidator(validator) - self.test_input.setValidator(validator) - - gridlayout = QGridLayout() - gridlayout.addWidget(info_msg, 0, 0, 1, 2) - gridlayout.addWidget(train_lbl, 1, 0) - gridlayout.addWidget(val_lbl, 2, 0) - gridlayout.addWidget(test_lbl, 3, 0) - gridlayout.addWidget(self.train_input, 1, 1) - gridlayout.addWidget(self.val_input, 2, 1) - gridlayout.addWidget(self.test_input, 3, 1) - - bb = BB(BB.Ok | BB.Cancel, Qt.Horizontal, self) - bb.button(BB.Ok).setIcon(newIcon('done')) - bb.button(BB.Cancel).setIcon(newIcon('undo')) - bb.accepted.connect(self.validate) - bb.rejected.connect(self.cancel) - gridlayout.addWidget(bb, 4, 0, 1, 2) - - self.setLayout(gridlayout) - - self.show() - - def validate(self): - self.flag_accept = True - self.accept() - - def cancel(self): - self.flag_accept = False - self.reject() - - def getStatus(self): - return self.flag_accept - - def getDataPartition(self): - self.train_ratio = int(self.train_input.text()) - self.val_ratio = int(self.val_input.text()) - self.test_ratio = int(self.test_input.text()) - - return self.train_ratio, self.val_ratio, self.test_ratio - - def closeEvent(self, event): - self.flag_accept = False - self.reject() - - diff --git a/PPOCRLabel/libs/utils.py b/PPOCRLabel/libs/utils.py index e397f139e0cf34de4fd517f920dd3fef12cc2cd7..1bd46ab4dac65f4e63e4ac4b2af5a8d295d89671 100644 --- a/PPOCRLabel/libs/utils.py +++ b/PPOCRLabel/libs/utils.py @@ -176,18 +176,6 @@ def boxPad(box, imgShape, pad : int) -> np.array: return box -def OBB2HBB(obb) -> np.array: - """ - Convert Oriented Bounding Box to Horizontal Bounding Box. - """ - hbb = np.zeros(4, dtype=np.int32) - hbb[0] = min(obb[:, 0]) - hbb[1] = min(obb[:, 1]) - hbb[2] = max(obb[:, 0]) - hbb[3] = max(obb[:, 1]) - return hbb - - def expand_list(merged, html_list): ''' Fill blanks according to merged cells @@ -232,6 +220,26 @@ def convert_token(html_list): return token_list +def rebuild_html_from_ppstructure_label(label_info): + from html import escape + html_code = label_info['html']['structure']['tokens'].copy() + to_insert = [ + i for i, tag in enumerate(html_code) if tag in ('', '>') + ] + for i, cell in zip(to_insert[::-1], label_info['html']['cells'][::-1]): + if cell['tokens']: + cell = [ + escape(token) if len(token) == 1 else token + for token in cell['tokens'] + ] + cell = ''.join(cell) + html_code.insert(i + 1, cell) + html_code = ''.join(html_code) + html_code = '{}
'.format( + html_code) + return html_code + + def stepsInfo(lang='en'): if lang == 'ch': msg = "1. 安装与运行:使用上述命令安装与运行程序。\n" \ diff --git "a/applications/\344\270\255\346\226\207\350\241\250\346\240\274\350\257\206\345\210\253.md" "b/applications/\344\270\255\346\226\207\350\241\250\346\240\274\350\257\206\345\210\253.md" new file mode 100644 index 0000000000000000000000000000000000000000..af7cc96b70410c614ef39e91c229d705c8bd400a --- /dev/null +++ "b/applications/\344\270\255\346\226\207\350\241\250\346\240\274\350\257\206\345\210\253.md" @@ -0,0 +1,472 @@ +# 智能运营:通用中文表格识别 + +- [1. 背景介绍](#1-背景介绍) +- [2. 中文表格识别](#2-中文表格识别) +- [2.1 环境准备](#21-环境准备) +- [2.2 准备数据集](#22-准备数据集) + - [2.2.1 划分训练测试集](#221-划分训练测试集) + - [2.2.2 查看数据集](#222-查看数据集) +- [2.3 训练](#23-训练) +- [2.4 验证](#24-验证) +- [2.5 训练引擎推理](#25-训练引擎推理) +- [2.6 模型导出](#26-模型导出) +- [2.7 预测引擎推理](#27-预测引擎推理) +- [2.8 表格识别](#28-表格识别) +- [3. 表格属性识别](#3-表格属性识别) +- [3.1 代码、环境、数据准备](#31-代码环境数据准备) + - [3.1.1 代码准备](#311-代码准备) + - [3.1.2 环境准备](#312-环境准备) + - [3.1.3 数据准备](#313-数据准备) +- [3.2 表格属性识别训练](#32-表格属性识别训练) +- [3.3 表格属性识别推理和部署](#33-表格属性识别推理和部署) + - [3.3.1 模型转换](#331-模型转换) + - [3.3.2 模型推理](#332-模型推理) + +## 1. 背景介绍 + +中文表格识别在金融行业有着广泛的应用,如保险理赔、财报分析和信息录入等领域。当前,金融行业的表格识别主要以手动录入为主,开发一种自动表格识别成为丞待解决的问题。 +![](https://ai-studio-static-online.cdn.bcebos.com/d1e7780f0c7745ada4be540decefd6288e4d59257d8141f6842682a4c05d28b6) + + +在金融行业中,表格图像主要有清单类的单元格密集型表格,申请表类的大单元格表格,拍照表格和倾斜表格四种主要形式。 + +![](https://ai-studio-static-online.cdn.bcebos.com/da82ae8ef8fd479aaa38e1049eb3a681cf020dc108fa458eb3ec79da53b45fd1) +![](https://ai-studio-static-online.cdn.bcebos.com/5ffff2093a144a6993a75eef71634a52276015ee43a04566b9c89d353198c746) + + +当前的表格识别算法不能很好的处理这些场景下的表格图像。在本例中,我们使用PP-Structurev2最新发布的表格识别模型SLANet来演示如何进行中文表格是识别。同时,为了方便作业流程,我们使用表格属性识别模型对表格图像的属性进行识别,对表格的难易程度进行判断,加快人工进行校对速度。 + +本项目AI Studio链接:https://aistudio.baidu.com/aistudio/projectdetail/4588067 + +## 2. 中文表格识别 +### 2.1 环境准备 + + +```python +# 下载PaddleOCR代码 +! git clone -b dygraph https://gitee.com/paddlepaddle/PaddleOCR +``` + + +```python +# 安装PaddleOCR环境 +! pip install -r PaddleOCR/requirements.txt --force-reinstall +! pip install protobuf==3.19 +``` + +### 2.2 准备数据集 + +本例中使用的数据集采用表格[生成工具](https://github.com/WenmuZhou/TableGeneration)制作。 + +使用如下命令对数据集进行解压,并查看数据集大小 + + +```python +! cd data/data165849 && tar -xf table_gen_dataset.tar && cd - +! 
wc -l data/data165849/table_gen_dataset/gt.txt +``` + +#### 2.2.1 划分训练测试集 + +使用下述命令将数据集划分为训练集和测试集, 这里将90%划分为训练集,10%划分为测试集 + + +```python +import random +with open('/home/aistudio/data/data165849/table_gen_dataset/gt.txt') as f: + lines = f.readlines() +random.shuffle(lines) +train_len = int(len(lines)*0.9) +train_list = lines[:train_len] +val_list = lines[train_len:] + +# 保存结果 +with open('/home/aistudio/train.txt','w',encoding='utf-8') as f: + f.writelines(train_list) +with open('/home/aistudio/val.txt','w',encoding='utf-8') as f: + f.writelines(val_list) +``` + +划分完成后,数据集信息如下 + +|类型|数量|图片地址|标注文件路径| +|---|---|---|---| +|训练集|18000|/home/aistudio/data/data165849/table_gen_dataset|/home/aistudio/train.txt| +|测试集|2000|/home/aistudio/data/data165849/table_gen_dataset|/home/aistudio/val.txt| + +#### 2.2.2 查看数据集 + + +```python +import cv2 +import os, json +import numpy as np +from matplotlib import pyplot as plt +%matplotlib inline + +def parse_line(data_dir, line): + data_line = line.strip("\n") + info = json.loads(data_line) + file_name = info['filename'] + cells = info['html']['cells'].copy() + structure = info['html']['structure']['tokens'].copy() + + img_path = os.path.join(data_dir, file_name) + if not os.path.exists(img_path): + print(img_path) + return None + data = { + 'img_path': img_path, + 'cells': cells, + 'structure': structure, + 'file_name': file_name + } + return data + +def draw_bbox(img_path, points, color=(255, 0, 0), thickness=2): + if isinstance(img_path, str): + img_path = cv2.imread(img_path) + img_path = img_path.copy() + for point in points: + cv2.polylines(img_path, [point.astype(int)], True, color, thickness) + return img_path + + +def rebuild_html(data): + html_code = data['structure'] + cells = data['cells'] + to_insert = [i for i, tag in enumerate(html_code) if tag in ('', '>')] + + for i, cell in zip(to_insert[::-1], cells[::-1]): + if cell['tokens']: + text = ''.join(cell['tokens']) + # skip empty text + sp_char_list = ['', '', '\u2028', ' ', '', ''] + text_remove_style = skip_char(text, sp_char_list) + if len(text_remove_style) == 0: + continue + html_code.insert(i + 1, text) + + html_code = ''.join(html_code) + return html_code + + +def skip_char(text, sp_char_list): + """ + skip empty cell + @param text: text in cell + @param sp_char_list: style char and special code + @return: + """ + for sp_char in sp_char_list: + text = text.replace(sp_char, '') + return text + +save_dir = '/home/aistudio/vis' +os.makedirs(save_dir, exist_ok=True) +image_dir = '/home/aistudio/data/data165849/' +html_str = '' + +# 解析标注信息并还原html表格 +data = parse_line(image_dir, val_list[0]) + +img = cv2.imread(data['img_path']) +img_name = ''.join(os.path.basename(data['file_name']).split('.')[:-1]) +img_save_name = os.path.join(save_dir, img_name) +boxes = [np.array(x['bbox']) for x in data['cells']] +show_img = draw_bbox(data['img_path'], boxes) +cv2.imwrite(img_save_name + '_show.jpg', show_img) + +html = rebuild_html(data) +html_str += html +html_str += '
' + +# 显示标注的html字符串 +from IPython.core.display import display, HTML +display(HTML(html_str)) +# 显示单元格坐标 +plt.figure(figsize=(15,15)) +plt.imshow(show_img) +plt.show() +``` + +### 2.3 训练 + +这里选用PP-Structurev2中的表格识别模型[SLANet](https://github.com/PaddlePaddle/PaddleOCR/blob/dygraph/configs/table/SLANet.yml) + +SLANet是PP-Structurev2全新推出的表格识别模型,相比PP-Structurev1中TableRec-RARE,在速度不变的情况下精度提升4.7%。TEDS提升2% + + +|算法|Acc|[TEDS(Tree-Edit-Distance-based Similarity)](https://github.com/ibm-aur-nlp/PubTabNet/tree/master/src)|Speed| +| --- | --- | --- | ---| +| EDD[2] |x| 88.3% |x| +| TableRec-RARE(ours) | 71.73%| 93.88% |779ms| +| SLANet(ours) | 76.31%| 95.89%|766ms| + +进行训练之前先使用如下命令下载预训练模型 + + +```python +# 进入PaddleOCR工作目录 +os.chdir('/home/aistudio/PaddleOCR') +# 下载英文预训练模型 +! wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/en_ppstructure_mobile_v2.0_SLANet_train.tar --no-check-certificate +! cd ./pretrain_models/ && tar xf en_ppstructure_mobile_v2.0_SLANet_train.tar && cd ../ +``` + +使用如下命令即可启动训练,需要修改的配置有 + +|字段|修改值|含义| +|---|---|---| +|Global.pretrained_model|./pretrain_models/en_ppstructure_mobile_v2.0_SLANet_train/best_accuracy.pdparams|指向英文表格预训练模型地址| +|Global.eval_batch_step|562|模型多少step评估一次,一般设置为一个epoch总的step数| +|Optimizer.lr.name|Const|学习率衰减器 | +|Optimizer.lr.learning_rate|0.0005|学习率设为之前的0.05倍 | +|Train.dataset.data_dir|/home/aistudio/data/data165849|指向训练集图片存放目录 | +|Train.dataset.label_file_list|/home/aistudio/data/data165849/table_gen_dataset/train.txt|指向训练集标注文件 | +|Train.loader.batch_size_per_card|32|训练时每张卡的batch_size | +|Train.loader.num_workers|1|训练集多进程数据读取的进程数,在aistudio中需要设为1 | +|Eval.dataset.data_dir|/home/aistudio/data/data165849|指向测试集图片存放目录 | +|Eval.dataset.label_file_list|/home/aistudio/data/data165849/table_gen_dataset/val.txt|指向测试集标注文件 | +|Eval.loader.batch_size_per_card|32|测试时每张卡的batch_size | +|Eval.loader.num_workers|1|测试集多进程数据读取的进程数,在aistudio中需要设为1 | + + +已经修改好的配置存储在 `/home/aistudio/SLANet_ch.yml` + + +```python +import os +os.chdir('/home/aistudio/PaddleOCR') +! python3 tools/train.py -c /home/aistudio/SLANet_ch.yml +``` + +大约在7个epoch后达到最高精度 97.49% + +### 2.4 验证 + +训练完成后,可使用如下命令在测试集上评估最优模型的精度 + + +```python +! python3 tools/eval.py -c /home/aistudio/SLANet_ch.yml -o Global.checkpoints=/home/aistudio/PaddleOCR/output/SLANet_ch/best_accuracy.pdparams +``` + +### 2.5 训练引擎推理 +使用如下命令可使用训练引擎对单张图片进行推理 + + +```python +import os;os.chdir('/home/aistudio/PaddleOCR') +! python3 tools/infer_table.py -c /home/aistudio/SLANet_ch.yml -o Global.checkpoints=/home/aistudio/PaddleOCR/output/SLANet_ch/best_accuracy.pdparams Global.infer_img=/home/aistudio/data/data165849/table_gen_dataset/img/no_border_18298_G7XZH93DDCMATGJQ8RW2.jpg +``` + + +```python +import cv2 +from matplotlib import pyplot as plt +%matplotlib inline + +# 显示原图 +show_img = cv2.imread('/home/aistudio/data/data165849/table_gen_dataset/img/no_border_18298_G7XZH93DDCMATGJQ8RW2.jpg') +plt.figure(figsize=(15,15)) +plt.imshow(show_img) +plt.show() + +# 显示预测的单元格 +show_img = cv2.imread('/home/aistudio/PaddleOCR/output/infer/no_border_18298_G7XZH93DDCMATGJQ8RW2.jpg') +plt.figure(figsize=(15,15)) +plt.imshow(show_img) +plt.show() +``` + +### 2.6 模型导出 + +使用如下命令可将模型导出为inference模型 + + +```python +! 
python3 tools/export_model.py -c /home/aistudio/SLANet_ch.yml -o Global.checkpoints=/home/aistudio/PaddleOCR/output/SLANet_ch/best_accuracy.pdparams Global.save_inference_dir=/home/aistudio/SLANet_ch/infer +``` + +### 2.7 预测引擎推理 +使用如下命令可使用预测引擎对单张图片进行推理 + + + +```python +os.chdir('/home/aistudio/PaddleOCR/ppstructure') +! python3 table/predict_structure.py \ + --table_model_dir=/home/aistudio/SLANet_ch/infer \ + --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt \ + --image_dir=/home/aistudio/data/data165849/table_gen_dataset/img/no_border_18298_G7XZH93DDCMATGJQ8RW2.jpg \ + --output=../output/inference +``` + + +```python +# 显示原图 +show_img = cv2.imread('/home/aistudio/data/data165849/table_gen_dataset/img/no_border_18298_G7XZH93DDCMATGJQ8RW2.jpg') +plt.figure(figsize=(15,15)) +plt.imshow(show_img) +plt.show() + +# 显示预测的单元格 +show_img = cv2.imread('/home/aistudio/PaddleOCR/output/inference/no_border_18298_G7XZH93DDCMATGJQ8RW2.jpg') +plt.figure(figsize=(15,15)) +plt.imshow(show_img) +plt.show() +``` + +### 2.8 表格识别 + +在表格结构模型训练完成后,可结合OCR检测识别模型,对表格内容进行识别。 + +首先下载PP-OCRv3文字检测识别模型 + + +```python +# 下载PP-OCRv3文本检测识别模型并解压 +! wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_slim_infer.tar --no-check-certificate +! wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_slim_infer.tar --no-check-certificate +! cd ./inference/ && tar xf ch_PP-OCRv3_det_slim_infer.tar && tar xf ch_PP-OCRv3_rec_slim_infer.tar && cd ../ +``` + +模型下载完成后,使用如下命令进行表格识别 + + +```python +import os;os.chdir('/home/aistudio/PaddleOCR/ppstructure') +! python3 table/predict_table.py \ + --det_model_dir=inference/ch_PP-OCRv3_det_slim_infer \ + --rec_model_dir=inference/ch_PP-OCRv3_rec_slim_infer \ + --table_model_dir=/home/aistudio/SLANet_ch/infer \ + --rec_char_dict_path=../ppocr/utils/ppocr_keys_v1.txt \ + --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt \ + --image_dir=/home/aistudio/data/data165849/table_gen_dataset/img/no_border_18298_G7XZH93DDCMATGJQ8RW2.jpg \ + --output=../output/table +``` + + +```python +# 显示原图 +show_img = cv2.imread('/home/aistudio/data/data165849/table_gen_dataset/img/no_border_18298_G7XZH93DDCMATGJQ8RW2.jpg') +plt.figure(figsize=(15,15)) +plt.imshow(show_img) +plt.show() + +# 显示预测结果 +from IPython.core.display import display, HTML +display(HTML('
<!-- predicted HTML table string for the demo image (table markup and cell contents omitted) -->
')) +``` + +## 3. 表格属性识别 +### 3.1 代码、环境、数据准备 +#### 3.1.1 代码准备 +首先,我们需要准备训练表格属性的代码,PaddleClas集成了PULC方案,该方案可以快速获得一个在CPU上用时2ms的属性识别模型。PaddleClas代码可以clone下载得到。获取方式如下: + + + +```python +! git clone -b develop https://gitee.com/paddlepaddle/PaddleClas +``` + +#### 3.1.2 环境准备 +其次,我们需要安装训练PaddleClas相关的依赖包 + + +```python +! pip install -r PaddleClas/requirements.txt --force-reinstall +! pip install protobuf==3.20.0 +``` + + +#### 3.1.3 数据准备 + +最后,准备训练数据。在这里,我们一共定义了表格的6个属性,分别是表格来源、表格数量、表格颜色、表格清晰度、表格有无干扰、表格角度。其可视化如下: + +![](https://user-images.githubusercontent.com/45199522/190587903-ccdfa6fb-51e8-42de-b08b-a127cb04e304.png) + +这里,我们提供了一个表格属性的demo子集,可以快速迭代体验。下载方式如下: + + +```python +%cd PaddleClas/dataset +!wget https://paddleclas.bj.bcebos.com/data/PULC/table_attribute.tar +!tar -xf table_attribute.tar +%cd ../PaddleClas/dataset +%cd ../ +``` + +### 3.2 表格属性识别训练 +表格属性训练整体pipelinie如下: + +![](https://user-images.githubusercontent.com/45199522/190599426-3415b38e-e16e-4e68-9253-2ff531b1b5ca.png) + +1.训练过程中,图片经过预处理之后,送入到骨干网络之中,骨干网络将抽取表格图片的特征,最终该特征连接输出的FC层,FC层经过Sigmoid激活函数后和真实标签做交叉熵损失函数,优化器通过对该损失函数做梯度下降来更新骨干网络的参数,经过多轮训练后,骨干网络的参数可以对为止图片做很好的预测; + +2.推理过程中,图片经过预处理之后,送入到骨干网络之中,骨干网络加载学习好的权重后对该表格图片做出预测,预测的结果为一个6维向量,该向量中的每个元素反映了每个属性对应的概率值,通过对该值进一步卡阈值之后,得到最终的输出,最终的输出描述了该表格的6个属性。 + +当准备好相关的数据之后,可以一键启动表格属性的训练,训练代码如下: + + +```python + +!python tools/train.py -c ./ppcls/configs/PULC/table_attribute/PPLCNet_x1_0.yaml -o Global.device=cpu -o Global.epochs=10 +``` + +### 3.3 表格属性识别推理和部署 +#### 3.3.1 模型转换 +当训练好模型之后,需要将模型转换为推理模型进行部署。转换脚本如下: + + +```python +!python tools/export_model.py -c ppcls/configs/PULC/table_attribute/PPLCNet_x1_0.yaml -o Global.pretrained_model=output/PPLCNet_x1_0/best_model +``` + +执行以上命令之后,会在当前目录上生成`inference`文件夹,该文件夹中保存了当前精度最高的推理模型。 + +#### 3.3.2 模型推理 +安装推理需要的paddleclas包, 此时需要通过下载安装paddleclas的develop的whl包 + + + +```python +!pip install https://paddleclas.bj.bcebos.com/whl/paddleclas-0.0.0-py3-none-any.whl +``` + +进入`deploy`目录下即可对模型进行推理 + + +```python +%cd deploy/ +``` + +推理命令如下: + + +```python +!python python/predict_cls.py -c configs/PULC/table_attribute/inference_table_attribute.yaml -o Global.inference_model_dir="../inference" -o Global.infer_imgs="../dataset/table_attribute/Table_val/val_9.jpg" +!python python/predict_cls.py -c configs/PULC/table_attribute/inference_table_attribute.yaml -o Global.inference_model_dir="../inference" -o Global.infer_imgs="../dataset/table_attribute/Table_val/val_3253.jpg" +``` + +推理的表格图片: + +![](https://user-images.githubusercontent.com/45199522/190596141-74f4feda-b082-46d7-908d-b0bd5839b430.png) + +预测结果如下: +``` +val_9.jpg: {'attributes': ['Scanned', 'Little', 'Black-and-White', 'Clear', 'Without-Obstacles', 'Horizontal'], 'output': [1, 1, 1, 1, 1, 1]} +``` + + +推理的表格图片: + +![](https://user-images.githubusercontent.com/45199522/190597086-2e685200-22d0-4042-9e46-f61f24e02e4e.png) + +预测结果如下: +``` +val_3253.jpg: {'attributes': ['Photo', 'Little', 'Black-and-White', 'Blurry', 'Without-Obstacles', 'Tilted'], 'output': [0, 1, 1, 0, 1, 0]} +``` + +对比两张图片可以发现,第一张图片比较清晰,表格属性的结果也偏向于比较容易识别,我们可以更相信表格识别的结果,第二张图片比较模糊,且存在倾斜现象,表格识别可能存在错误,需要我们人工进一步校验。通过表格的属性识别能力,可以进一步将“人工”和“智能”很好的结合起来,为表格识别能力的落地的精度提供保障。 diff --git "a/applications/\345\215\260\347\253\240\345\274\257\346\233\262\346\226\207\345\255\227\350\257\206\345\210\253.md" "b/applications/\345\215\260\347\253\240\345\274\257\346\233\262\346\226\207\345\255\227\350\257\206\345\210\253.md" new file mode 100644 index 0000000000000000000000000000000000000000..fce9ea772eed6575de10f50c0ff447aa1aee928b --- /dev/null +++ 
"b/applications/\345\215\260\347\253\240\345\274\257\346\233\262\346\226\207\345\255\227\350\257\206\345\210\253.md" @@ -0,0 +1,1033 @@ +# 印章弯曲文字识别 + +- [1. 项目介绍](#1-----) +- [2. 环境搭建](#2-----) + * [2.1 准备PaddleDetection环境](#21---paddledetection--) + * [2.2 准备PaddleOCR环境](#22---paddleocr--) +- [3. 数据集准备](#3------) + * [3.1 数据标注](#31-----) + * [3.2 数据处理](#32-----) +- [4. 印章检测实践](#4-------) +- [5. 印章文字识别实践](#5---------) + * [5.1 端对端印章文字识别实践](#51------------) + * [5.2 两阶段印章文字识别实践](#52------------) + + [5.2.1 印章文字检测](#521-------) + + [5.2.2 印章文字识别](#522-------) + + +# 1. 项目介绍 + +弯曲文字识别在OCR任务中有着广泛的应用,比如:自然场景下的招牌,艺术文字,以及常见的印章文字识别。 + +在本项目中,将以印章识别任务为例,介绍如何使用PaddleDetection和PaddleOCR完成印章检测和印章文字识别任务。 + +项目难点: +1. 缺乏训练数据 +2. 图像质量参差不齐,图像模糊,文字不清晰 + +针对以上问题,本项目选用PaddleOCR里的PPOCRLabel工具完成数据标注。基于PaddleDetection完成印章区域检测,然后通过PaddleOCR里的端对端OCR算法和两阶段OCR算法分别完成印章文字识别任务。不同任务的精度效果如下: + + +| 任务 | 训练数据数量 | 精度 | +| -------- | - | -------- | +| 印章检测 | 1000 | 95% | +| 印章文字识别-端对端OCR方法 | 700 | 47% | +| 印章文字识别-两阶段OCR方法 | 700 | 55% | + +点击进入 [AI Studio 项目](https://aistudio.baidu.com/aistudio/projectdetail/4586113) + +# 2. 环境搭建 + +本项目需要准备PaddleDetection和PaddleOCR的项目运行环境,其中PaddleDetection用于实现印章检测任务,PaddleOCR用于实现文字识别任务 + + +## 2.1 准备PaddleDetection环境 + +下载PaddleDetection代码: +``` +!git clone https://github.com/PaddlePaddle/PaddleDetection.git +# 如果克隆github代码较慢,请从gitee上克隆代码 +#git clone https://gitee.com/PaddlePaddle/PaddleDetection.git +``` + +安装PaddleDetection依赖 +``` +!cd PaddleDetection && pip install -r requirements.txt +``` + +## 2.2 准备PaddleOCR环境 + +下载PaddleOCR代码: +``` +!git clone https://github.com/PaddlePaddle/PaddleOCR.git +# 如果克隆github代码较慢,请从gitee上克隆代码 +#git clone https://gitee.com/PaddlePaddle/PaddleOCR.git +``` + +安装PaddleOCR依赖 +``` +!cd PaddleOCR && git checkout dygraph && pip install -r requirements.txt +``` + +# 3. 
数据集准备 + +## 3.1 数据标注 + +本项目中使用[PPOCRLabel](https://github.com/PaddlePaddle/PaddleOCR/tree/release/2.6/PPOCRLabel)工具标注印章检测数据,标注内容包括印章的位置以及印章中文字的位置和文字内容。 + + +注:PPOCRLabel的使用方法参考[文档](https://github.com/PaddlePaddle/PaddleOCR/tree/release/2.6/PPOCRLabel)。 + +PPOCRlabel标注印章数据步骤: +- 打开数据集所在文件夹 +- 按下快捷键Q进行4点(多点)标注——针对印章文本识别, + - 印章弯曲文字包围框采用偶数点标注(比如4点,8点,16点),按照阅读顺序,以16点标注为例,从文字左上方开始标注->到文字右上方标注8个点->到文字右下方->文字左下方8个点,一共8个点,形成包围曲线,参考下图。如果文字弯曲程度不高,为了减小标注工作量,可以采用4点、8点标注,需要注意的是,文字上下点数相同。(总点数尽量不要超过18个) + - 对于需要识别的印章中非弯曲文字,采用4点框标注即可 + - 对应包围框的文字部分默认是”待识别”,需要修改为包围框内的具体文字内容 +- 快捷键W进行矩形标注——针对印章区域检测,印章检测区域保证标注框包围整个印章,包围框对应文字可以设置为'印章区域',方便后续处理。 +- 针对印章中的水平文字可以视情况考虑矩形或四点标注:保证按行标注即可。如果背景文字与印章文字比较接近,标注时尽量避开背景文字。 +- 标注完成后修改右侧文本结果,确认无误后点击下方check(或CTRL+V),确认本张图片的标注。 +- 所有图片标注完成后,在顶部菜单栏点击File -> Export Label导出label.txt。 + +标注完成后,可视化效果如下: +![](https://ai-studio-static-online.cdn.bcebos.com/f5acbc4f50dd401a8f535ed6a263f94b0edff82c1aed4285836a9ead989b9c13) + +数据标注完成后,标签中包含印章检测的标注和印章文字识别的标注,如下所示: +``` +img/1.png [{"transcription": "印章区域", "points": [[87, 245], [214, 245], [214, 369], [87, 369]], "difficult": false}, {"transcription": "国家税务总局泸水市税务局第二税务分局", "points": [[110, 314], [116, 290], [131, 275], [152, 273], [170, 277], [181, 289], [186, 303], [186, 312], [201, 311], [198, 289], [189, 272], [175, 259], [152, 252], [124, 257], [100, 280], [94, 312]], "difficult": false}, {"transcription": "征税专用章", "points": [[117, 334], [183, 334], [183, 352], [117, 352]], "difficult": false}] +``` +标注中包含表示'印章区域'的坐标和'印章文字'坐标以及文字内容。 + + + +## 3.2 数据处理 + +标注时为了方便标注,没有区分印章区域的标注框和文字区域的标注框,可以通过python代码完成标签的划分。 + +在本项目的'/home/aistudio/work/seal_labeled_datas'目录下,存放了标注的数据示例,如下: + + +![](https://ai-studio-static-online.cdn.bcebos.com/3d762970e2184177a2c633695a31029332a4cd805631430ea797309492e45402) + +标签文件'/home/aistudio/work/seal_labeled_datas/Label.txt'中的标注内容如下: + +``` +img/test1.png [{"transcription": "待识别", "points": [[408, 232], [537, 232], [537, 352], [408, 352]], "difficult": false}, {"transcription": "电子回单", "points": [[437, 305], [504, 305], [504, 322], [437, 322]], "difficult": false}, {"transcription": "云南省农村信用社", "points": [[417, 290], [434, 295], [438, 281], [446, 267], [455, 261], [472, 258], [489, 264], [498, 277], [502, 295], [526, 289], [518, 267], [503, 249], [475, 232], [446, 239], [429, 255], [418, 275]], "difficult": false}, {"transcription": "专用章", "points": [[437, 319], [503, 319], [503, 338], [437, 338]], "difficult": false}] +``` + + +为了方便训练,我们需要通过python代码将用于训练印章检测和训练印章文字识别的标注区分开。 + + +``` +import numpy as np +import json +import cv2 +import os +from shapely.geometry import Polygon + + +def poly2box(poly): + xmin = np.min(np.array(poly)[:, 0]) + ymin = np.min(np.array(poly)[:, 1]) + xmax = np.max(np.array(poly)[:, 0]) + ymax = np.max(np.array(poly)[:, 1]) + return np.array([[xmin, ymin], [xmax, ymin], [xmax, ymax], [xmin, ymax]]) + + +def draw_text_det_res(dt_boxes, src_im, color=(255, 255, 0)): + for box in dt_boxes: + box = np.array(box).astype(np.int32).reshape(-1, 2) + cv2.polylines(src_im, [box], True, color=color, thickness=2) + return src_im + +class LabelDecode(object): + def __init__(self, **kwargs): + pass + + def __call__(self, data): + label = json.loads(data['label']) + + nBox = len(label) + seal_boxes = self.get_seal_boxes(label) + + gt_label = [] + + for seal_box in seal_boxes: + seal_anno = {'seal_box': seal_box} + boxes, txts, txt_tags = [], [], [] + + for bno in range(0, nBox): + box = label[bno]['points'] + txt = label[bno]['transcription'] + try: + ints = self.get_intersection(box, seal_box) + 
except Exception as E: + print(E) + continue + + if abs(Polygon(box).area - self.get_intersection(box, seal_box)) < 1e-3 and \ + abs(Polygon(box).area - self.get_union(box, seal_box)) > 1e-3: + + boxes.append(box) + txts.append(txt) + if txt in ['*', '###', '待识别']: + txt_tags.append(True) + else: + txt_tags.append(False) + + seal_anno['polys'] = boxes + seal_anno['texts'] = txts + seal_anno['ignore_tags'] = txt_tags + + gt_label.append(seal_anno) + + return gt_label + + def get_seal_boxes(self, label): + + nBox = len(label) + seal_box = [] + for bno in range(0, nBox): + box = label[bno]['points'] + if len(box) == 4: + seal_box.append(box) + + if len(seal_box) == 0: + return None + + seal_box = self.valid_seal_box(seal_box) + return seal_box + + + def is_seal_box(self, box, boxes): + is_seal = True + for poly in boxes: + if list(box.shape()) != list(box.shape.shape()): + if abs(Polygon(box).area - self.get_intersection(box, poly)) < 1e-3: + return False + else: + if np.sum(np.array(box) - np.array(poly)) < 1e-3: + # continue when the box is same with poly + continue + if abs(Polygon(box).area - self.get_intersection(box, poly)) < 1e-3: + return False + return is_seal + + + def valid_seal_box(self, boxes): + if len(boxes) == 1: + return boxes + + new_boxes = [] + flag = True + for k in range(0, len(boxes)): + flag = True + tmp_box = boxes[k] + for i in range(0, len(boxes)): + if k == i: continue + if abs(Polygon(tmp_box).area - self.get_intersection(tmp_box, boxes[i])) < 1e-3: + flag = False + continue + if flag: + new_boxes.append(tmp_box) + + return new_boxes + + + def get_union(self, pD, pG): + return Polygon(pD).union(Polygon(pG)).area + + def get_intersection_over_union(self, pD, pG): + return get_intersection(pD, pG) / get_union(pD, pG) + + def get_intersection(self, pD, pG): + return Polygon(pD).intersection(Polygon(pG)).area + + def expand_points_num(self, boxes): + max_points_num = 0 + for box in boxes: + if len(box) > max_points_num: + max_points_num = len(box) + ex_boxes = [] + for box in boxes: + ex_box = box + [box[-1]] * (max_points_num - len(box)) + ex_boxes.append(ex_box) + return ex_boxes + + +def gen_extract_label(data_dir, label_file, seal_gt, seal_ppocr_gt): + label_decode_func = LabelDecode() + gts = open(label_file, "r").readlines() + + seal_gt_list = [] + seal_ppocr_list = [] + + for idx, line in enumerate(gts): + img_path, label = line.strip().split("\t") + data = {'label': label, 'img_path':img_path} + res = label_decode_func(data) + src_img = cv2.imread(os.path.join(data_dir, img_path)) + if res is None: + print("ERROR! 
res is None!") + continue + + anno = [] + for i, gt in enumerate(res): + # print(i, box, type(box), ) + anno.append({'polys': gt['seal_box'], 'cls':1}) + + seal_gt_list.append(f"{img_path}\t{json.dumps(anno)}\n") + seal_ppocr_list.append(f"{img_path}\t{json.dumps(res)}\n") + + if not os.path.exists(os.path.dirname(seal_gt)): + os.makedirs(os.path.dirname(seal_gt)) + if not os.path.exists(os.path.dirname(seal_ppocr_gt)): + os.makedirs(os.path.dirname(seal_ppocr_gt)) + + with open(seal_gt, "w") as f: + f.writelines(seal_gt_list) + f.close() + + with open(seal_ppocr_gt, 'w') as f: + f.writelines(seal_ppocr_list) + f.close() + +def vis_seal_ppocr(data_dir, label_file, save_dir): + + datas = open(label_file, 'r').readlines() + for idx, line in enumerate(datas): + img_path, label = line.strip().split('\t') + img_path = os.path.join(data_dir, img_path) + + label = json.loads(label) + src_im = cv2.imread(img_path) + if src_im is None: + continue + + for anno in label: + seal_box = anno['seal_box'] + txt_boxes = anno['polys'] + + # vis seal box + src_im = draw_text_det_res([seal_box], src_im, color=(255, 255, 0)) + src_im = draw_text_det_res(txt_boxes, src_im, color=(255, 0, 0)) + + save_path = os.path.join(save_dir, os.path.basename(img_path)) + if not os.path.exists(save_dir): + os.makedirs(save_dir) + # print(src_im.shape) + cv2.imwrite(save_path, src_im) + + +def draw_html(img_dir, save_name): + import glob + + images_dir = glob.glob(img_dir + "/*") + print(len(images_dir)) + + html_path = save_name + with open(html_path, 'w') as html: + html.write('\n\n') + html.write('\n') + html.write("") + + html.write("\n") + html.write(f'\n") + html.write(f'' % (base)) + html.write("\n") + + html.write('\n') + html.write('
\n GT') + + for i, filename in enumerate(sorted(images_dir)): + if filename.endswith("txt"): continue + print(filename) + + base = "{}".format(filename) + if True: + html.write("
{filename}\n GT') + html.write('GT 310\n
\n') + html.write('\n\n') + print("ok") + + +def crop_seal_from_img(label_file, data_dir, save_dir, save_gt_path): + + if not os.path.exists(save_dir): + os.makedirs(save_dir) + + datas = open(label_file, 'r').readlines() + all_gts = [] + count = 0 + for idx, line in enumerate(datas): + img_path, label = line.strip().split('\t') + img_path = os.path.join(data_dir, img_path) + + label = json.loads(label) + src_im = cv2.imread(img_path) + if src_im is None: + continue + + for c, anno in enumerate(label): + seal_poly = anno['seal_box'] + txt_boxes = anno['polys'] + txts = anno['texts'] + ignore_tags = anno['ignore_tags'] + + box = poly2box(seal_poly) + img_crop = src_im[box[0][1]:box[2][1], box[0][0]:box[2][0], :] + + save_path = os.path.join(save_dir, f"{idx}_{c}.jpg") + cv2.imwrite(save_path, np.array(img_crop)) + + img_gt = [] + for i in range(len(txts)): + txt_boxes_crop = np.array(txt_boxes[i]) + txt_boxes_crop[:, 1] -= box[0, 1] + txt_boxes_crop[:, 0] -= box[0, 0] + img_gt.append({'transcription': txts[i], "points": txt_boxes_crop.tolist(), "ignore_tag": ignore_tags[i]}) + + if len(img_gt) >= 1: + count += 1 + save_gt = f"{os.path.basename(save_path)}\t{json.dumps(img_gt)}\n" + + all_gts.append(save_gt) + + print(f"The num of all image: {len(all_gts)}, and the number of useful image: {count}") + if not os.path.exists(os.path.dirname(save_gt_path)): + os.makedirs(os.path.dirname(save_gt_path)) + + with open(save_gt_path, "w") as f: + f.writelines(all_gts) + f.close() + print("Done") + + + +if __name__ == "__main__": + + # 数据处理 + gen_extract_label("./seal_labeled_datas", "./seal_labeled_datas/Label.txt", "./seal_ppocr_gt/seal_det_img.txt", "./seal_ppocr_gt/seal_ppocr_img.txt") + vis_seal_ppocr("./seal_labeled_datas", "./seal_ppocr_gt/seal_ppocr_img.txt", "./seal_ppocr_gt/seal_ppocr_vis/") + draw_html("./seal_ppocr_gt/seal_ppocr_vis/", "./vis_seal_ppocr.html") + seal_ppocr_img_label = "./seal_ppocr_gt/seal_ppocr_img.txt" + crop_seal_from_img(seal_ppocr_img_label, "./seal_labeled_datas/", "./seal_img_crop", "./seal_img_crop/label.txt") + +``` + +处理完成后,生成的文件如下: +``` +├── seal_img_crop/ +│ ├── 0_0.jpg +│ ├── ... +│ └── label.txt +├── seal_ppocr_gt/ +│ ├── seal_det_img.txt +│ ├── seal_ppocr_img.txt +│ └── seal_ppocr_vis/ +│ ├── test1.png +│ ├── ... 
+└── vis_seal_ppocr.html + +``` +其中`seal_img_crop/label.txt`文件为印章识别标签文件,其内容格式为: +``` +0_0.jpg [{"transcription": "\u7535\u5b50\u56de\u5355", "points": [[29, 73], [96, 73], [96, 90], [29, 90]], "ignore_tag": false}, {"transcription": "\u4e91\u5357\u7701\u519c\u6751\u4fe1\u7528\u793e", "points": [[9, 58], [26, 63], [30, 49], [38, 35], [47, 29], [64, 26], [81, 32], [90, 45], [94, 63], [118, 57], [110, 35], [95, 17], [67, 0], [38, 7], [21, 23], [10, 43]], "ignore_tag": false}, {"transcription": "\u4e13\u7528\u7ae0", "points": [[29, 87], [95, 87], [95, 106], [29, 106]], "ignore_tag": false}] +``` +可以直接用于PaddleOCR的PGNet算法的训练。 + +`seal_ppocr_gt/seal_det_img.txt`为印章检测标签文件,其内容格式为: +``` +img/test1.png [{"polys": [[408, 232], [537, 232], [537, 352], [408, 352]], "cls": 1}] +``` +为了使用PaddleDetection工具完成印章检测模型的训练,需要将`seal_det_img.txt`转换为COCO或者VOC的数据标注格式。 + +可以直接使用下述代码将印章检测标注转换成VOC格式。 + + +``` +import numpy as np +import json +import cv2 +import os +from shapely.geometry import Polygon + +seal_train_gt = "./seal_ppocr_gt/seal_det_img.txt" +# 注:仅用于示例,实际使用中需要分别转换训练集和测试集的标签 +seal_valid_gt = "./seal_ppocr_gt/seal_det_img.txt" + +def gen_main_train_txt(mode='train'): + if mode == "train": + file_path = seal_train_gt + if mode in ['valid', 'test']: + file_path = seal_valid_gt + + save_path = f"./seal_VOC/ImageSets/Main/{mode}.txt" + save_train_path = f"./seal_VOC/{mode}.txt" + if not os.path.exists(os.path.dirname(save_path)): + os.makedirs(os.path.dirname(save_path)) + + datas = open(file_path, 'r').readlines() + img_names = [] + train_names = [] + for line in datas: + img_name = line.strip().split('\t')[0] + img_name = os.path.basename(img_name) + (i_name, extension) = os.path.splitext(img_name) + t_name = 'JPEGImages/'+str(img_name)+' '+'Annotations/'+str(i_name)+'.xml\n' + train_names.append(t_name) + img_names.append(i_name + "\n") + + with open(save_train_path, "w") as f: + f.writelines(train_names) + f.close() + + with open(save_path, "w") as f: + f.writelines(img_names) + f.close() + + print(f"{mode} save done") + + +def gen_xml_label(mode='train'): + if mode == "train": + file_path = seal_train_gt + if mode in ['valid', 'test']: + file_path = seal_valid_gt + + datas = open(file_path, 'r').readlines() + img_names = [] + train_names = [] + anno_path = "./seal_VOC/Annotations" + img_path = "./seal_VOC/JPEGImages" + + if not os.path.exists(anno_path): + os.makedirs(anno_path) + if not os.path.exists(img_path): + os.makedirs(img_path) + + for idx, line in enumerate(datas): + img_name, label = line.strip().split('\t') + img = cv2.imread(os.path.join("./seal_labeled_datas", img_name)) + cv2.imwrite(os.path.join(img_path, os.path.basename(img_name)), img) + height, width, c = img.shape + img_name = os.path.basename(img_name) + (i_name, extension) = os.path.splitext(img_name) + label = json.loads(label) + + xml_file = open(("./seal_VOC/Annotations" + '/' + i_name + '.xml'), 'w') + xml_file.write('\n') + xml_file.write(' seal_VOC\n') + xml_file.write(' ' + str(img_name) + '\n') + xml_file.write(' ' + 'Annotations/' + str(img_name) + '\n') + xml_file.write(' \n') + xml_file.write(' ' + str(width) + '\n') + xml_file.write(' ' + str(height) + '\n') + xml_file.write(' 3\n') + xml_file.write(' \n') + xml_file.write(' 0\n') + + for anno in label: + poly = anno['polys'] + if anno['cls'] == 1: + gt_cls = 'redseal' + xmin = np.min(np.array(poly)[:, 0]) + ymin = np.min(np.array(poly)[:, 1]) + xmax = np.max(np.array(poly)[:, 0]) + ymax = np.max(np.array(poly)[:, 1]) + xmin,ymin,xmax,ymax= 
int(xmin),int(ymin),int(xmax),int(ymax) + xml_file.write(' \n') + xml_file.write(' '+str(gt_cls)+'\n') + xml_file.write(' Unspecified\n') + xml_file.write(' 0\n') + xml_file.write(' 0\n') + xml_file.write(' \n') + xml_file.write(' '+str(xmin)+'\n') + xml_file.write(' '+str(ymin)+'\n') + xml_file.write(' '+str(xmax)+'\n') + xml_file.write(' '+str(ymax)+'\n') + xml_file.write(' \n') + xml_file.write(' \n') + xml_file.write('') + xml_file.close() + print(f'{mode} xml save done!') + + +gen_main_train_txt() +gen_main_train_txt('valid') +gen_xml_label('train') +gen_xml_label('valid') + +``` + +数据处理完成后,转换为VOC格式的印章检测数据存储在~/data/seal_VOC目录下,目录组织结构为: + +``` +├── Annotations/ +├── ImageSets/ +│   └── Main/ +│   ├── train.txt +│   └── valid.txt +├── JPEGImages/ +├── train.txt +└── valid.txt +└── label_list.txt +``` + +Annotations下为数据的标签,JPEGImages目录下为图像文件,label_list.txt为标注检测框类别标签文件。 + +在接下来一节中,将介绍如何使用PaddleDetection工具库完成印章检测模型的训练。 + +# 4. 印章检测实践 + +在实际应用中,印章多是出现在合同,发票,公告等场景中,印章文字识别的任务需要排除图像中背景文字的影响,因此需要先检测出图像中的印章区域。 + + +借助PaddleDetection目标检测库可以很容易的实现印章检测任务,使用PaddleDetection训练印章检测任务流程如下: + +- 选择算法 +- 修改数据集配置路径 +- 启动训练 + + +**算法选择** + +PaddleDetection中有许多检测算法可以选择,考虑到每条数据中印章区域较为清晰,且考虑到性能需求。在本项目中,我们采用mobilenetv3为backbone的ppyolo算法完成印章检测任务,对应的配置文件是:configs/ppyolo/ppyolo_mbv3_large.yml + + + +**修改配置文件** + +配置文件中的默认数据路径是COCO, +需要修改为印章检测的数据路径,主要修改如下: +在配置文件'configs/ppyolo/ppyolo_mbv3_large.yml'末尾增加如下内容: +``` +metric: VOC +map_type: 11point +num_classes: 2 + +TrainDataset: + !VOCDataSet + dataset_dir: dataset/seal_VOC + anno_path: train.txt + label_list: label_list.txt + data_fields: ['image', 'gt_bbox', 'gt_class', 'difficult'] + +EvalDataset: + !VOCDataSet + dataset_dir: dataset/seal_VOC + anno_path: test.txt + label_list: label_list.txt + data_fields: ['image', 'gt_bbox', 'gt_class', 'difficult'] + +TestDataset: + !ImageFolder + anno_path: dataset/seal_VOC/label_list.txt +``` + +配置文件中设置的数据路径在PaddleDetection/dataset目录下,我们可以将处理后的印章检测训练数据移动到PaddleDetection/dataset目录下或者创建一个软连接。 + +``` +!ln -s seal_VOC ./PaddleDetection/dataset/ +``` + +另外图象中印章数量比较少,可以调整NMS后处理的检测框数量,即keep_top_k,nms_top_k 从100,1000,调整为10,100。在配置文件'configs/ppyolo/ppyolo_mbv3_large.yml'末尾增加如下内容完成后处理参数的调整 +``` +BBoxPostProcess: + decode: + name: YOLOBox + conf_thresh: 0.005 + downsample_ratio: 32 + clip_bbox: true + scale_x_y: 1.05 + nms: + name: MultiClassNMS + keep_top_k: 10 # 修改前100 + nms_threshold: 0.45 + nms_top_k: 100 # 修改前1000 + score_threshold: 0.005 +``` + + +修改完成后,需要在PaddleDetection中增加印章数据的处理代码,即在PaddleDetection/ppdet/data/source/目录下创建seal.py文件,文件中填充如下代码: +``` +import os +import numpy as np +from ppdet.core.workspace import register, serializable +from .dataset import DetDataset +import cv2 +import json + +from ppdet.utils.logger import setup_logger +logger = setup_logger(__name__) + + +@register +@serializable +class SealDataSet(DetDataset): + """ + Load dataset with COCO format. + + Args: + dataset_dir (str): root directory for dataset. + image_dir (str): directory for images. + anno_path (str): coco annotation file path. + data_fields (list): key name of data dictionary, at least have 'image'. + sample_num (int): number of samples to load, -1 means all. + load_crowd (bool): whether to load crowded ground-truth. + False as default + allow_empty (bool): whether to load empty entry. False as default + empty_ratio (float): the ratio of empty record number to total + record's, if empty_ratio is out of [0. ,1.), do not sample the + records and use all the empty entries. 1. 
as default + """ + + def __init__(self, + dataset_dir=None, + image_dir=None, + anno_path=None, + data_fields=['image'], + sample_num=-1, + load_crowd=False, + allow_empty=False, + empty_ratio=1.): + super(SealDataSet, self).__init__(dataset_dir, image_dir, anno_path, + data_fields, sample_num) + self.load_image_only = False + self.load_semantic = False + self.load_crowd = load_crowd + self.allow_empty = allow_empty + self.empty_ratio = empty_ratio + + def _sample_empty(self, records, num): + # if empty_ratio is out of [0. ,1.), do not sample the records + if self.empty_ratio < 0. or self.empty_ratio >= 1.: + return records + import random + sample_num = min( + int(num * self.empty_ratio / (1 - self.empty_ratio)), len(records)) + records = random.sample(records, sample_num) + return records + + def parse_dataset(self): + anno_path = os.path.join(self.dataset_dir, self.anno_path) + image_dir = os.path.join(self.dataset_dir, self.image_dir) + + records = [] + empty_records = [] + ct = 0 + + assert anno_path.endswith('.txt'), \ + 'invalid seal_gt file: ' + anno_path + + all_datas = open(anno_path, 'r').readlines() + + for idx, line in enumerate(all_datas): + im_path, label = line.strip().split('\t') + img_path = os.path.join(image_dir, im_path) + label = json.loads(label) + im_h, im_w, im_c = cv2.imread(img_path).shape + + coco_rec = { + 'im_file': img_path, + 'im_id': np.array([idx]), + 'h': im_h, + 'w': im_w, + } if 'image' in self.data_fields else {} + + if not self.load_image_only: + bboxes = [] + for anno in label: + poly = anno['polys'] + # poly to box + x1 = np.min(np.array(poly)[:, 0]) + y1 = np.min(np.array(poly)[:, 1]) + x2 = np.max(np.array(poly)[:, 0]) + y2 = np.max(np.array(poly)[:, 1]) + eps = 1e-5 + if x2 - x1 > eps and y2 - y1 > eps: + clean_box = [ + round(float(x), 3) for x in [x1, y1, x2, y2] + ] + anno = {'clean_box': clean_box, 'gt_cls':int(anno['cls'])} + bboxes.append(anno) + else: + logger.info("invalid box") + + num_bbox = len(bboxes) + if num_bbox <= 0: + continue + + gt_bbox = np.zeros((num_bbox, 4), dtype=np.float32) + gt_class = np.zeros((num_bbox, 1), dtype=np.int32) + is_crowd = np.zeros((num_bbox, 1), dtype=np.int32) + # gt_poly = [None] * num_bbox + + for i, box in enumerate(bboxes): + gt_class[i][0] = box['gt_cls'] + gt_bbox[i, :] = box['clean_box'] + is_crowd[i][0] = 0 + + gt_rec = { + 'is_crowd': is_crowd, + 'gt_class': gt_class, + 'gt_bbox': gt_bbox, + # 'gt_poly': gt_poly, + } + + for k, v in gt_rec.items(): + if k in self.data_fields: + coco_rec[k] = v + + records.append(coco_rec) + ct += 1 + if self.sample_num > 0 and ct >= self.sample_num: + break + self.roidbs = records +``` + +**启动训练** + +启动单卡训练的命令为: +``` +!python3 tools/train.py -c configs/ppyolo/ppyolo_mbv3_large.yml --eval + +# 分布式训练命令为: +!python3 -m paddle.distributed.launch --gpus 0,1,2,3,4,5,6,7 tools/train.py -c configs/ppyolo/ppyolo_mbv3_large.yml --eval +``` + +训练完成后,日志中会打印模型的精度: + +``` +[07/05 11:42:09] ppdet.engine INFO: Eval iter: 0 +[07/05 11:42:14] ppdet.metrics.metrics INFO: Accumulating evaluatation results... +[07/05 11:42:14] ppdet.metrics.metrics INFO: mAP(0.50, 11point) = 99.31% +[07/05 11:42:14] ppdet.engine INFO: Total sample number: 112, averge FPS: 26.45840794253432 +[07/05 11:42:14] ppdet.engine INFO: Best test bbox ap is 0.996. 
+``` + + +我们可以使用训练好的模型观察预测结果: +``` +!python3 tools/infer.py -c configs/ppyolo/ppyolo_mbv3_large.yml -o weights=./output/ppyolo_mbv3_large/model_final.pdparams --img_dir=./test.jpg +``` +预测结果如下: + +![](https://ai-studio-static-online.cdn.bcebos.com/0f650c032b0f4d56bd639713924768cc820635e9977845008d233f465291a29e) + +# 5. 印章文字识别实践 + +在使用ppyolo检测到印章区域后,接下来借助PaddleOCR里的文字识别能力,完成印章中文字的识别。 + +PaddleOCR中的OCR算法包含文字检测算法,文字识别算法以及OCR端对端算法。 + +文字检测算法负责检测到图像中的文字,再由文字识别模型识别出检测到的文字,进而实现OCR的任务。文字检测+文字识别串联完成OCR任务的架构称为两阶段的OCR算法。相对应的端对端的OCR方法可以用一个算法同时完成文字检测和识别的任务。 + + +| 文字检测 | 文字识别 | 端对端算法 | +| -------- | -------- | -------- | +| DB\DB++\EAST\SAST\PSENet | SVTR\CRNN\NRTN\Abinet\SAR\... | PGNet | + + +本节中将分别介绍端对端的文字检测识别算法以及两阶段的文字检测识别算法在印章检测识别任务上的实践。 + + +## 5.1 端对端印章文字识别实践 + +本节介绍使用PaddleOCR里的PGNet算法完成印章文字识别。 + +PGNet属于端对端的文字检测识别算法,在PaddleOCR中的配置文件为: +[PaddleOCR/configs/e2e/e2e_r50_vd_pg.yml](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.6/configs/e2e/e2e_r50_vd_pg.yml) + +使用PGNet完成文字检测识别任务的步骤为: +- 修改配置文件 +- 启动训练 + +PGNet默认配置文件的数据路径为totaltext数据集路径,本次训练中,需要修改为上一节数据处理后得到的标签文件和数据目录: + +训练数据配置修改后如下: +``` +Train: + dataset: + name: PGDataSet + data_dir: ./train_data/seal_ppocr + label_file_list: [./train_data/seal_ppocr/seal_ppocr_img.txt] + ratio_list: [1.0] +``` +测试数据集配置修改后如下: +``` +Eval: + dataset: + name: PGDataSet + data_dir: ./train_data/seal_ppocr_test + label_file_list: [./train_data/seal_ppocr_test/seal_ppocr_img.txt] +``` + +启动训练的命令为: +``` +!python3 tools/train.py -c configs/e2e/e2e_r50_vd_pg.yml +``` +模型训练完成后,可以得到最终的精度为47.4%。数据量较少,以及数据质量较差会影响模型的训练精度,如果有更多的数据参与训练,精度将进一步提升。 + +如需获取已训练模型,请扫文末的二维码填写问卷,加入PaddleOCR官方交流群获取全部OCR垂类模型下载链接、《动手学OCR》电子书等全套OCR学习资料🎁 + +## 5.2 两阶段印章文字识别实践 + +上一节介绍了使用PGNet实现印章识别任务的训练流程。本小节将介绍使用PaddleOCR里的文字检测和文字识别算法分别完成印章文字的检测和识别。 + +### 5.2.1 印章文字检测 + +PaddleOCR中包含丰富的文字检测算法,包含DB,DB++,EAST,SAST,PSENet等等。其中DB,DB++,PSENet均支持弯曲文字检测,本项目中,使用DB++作为印章弯曲文字检测算法。 + +PaddleOCR中发布的db++文字检测算法模型是英文文本检测模型,因此需要重新训练模型。 + + +修改[DB++配置文件](DB++的默认配置文件位于[configs/det/det_r50_db++_icdar15.yml](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.6/configs/det/det_r50_db%2B%2B_icdar15.yml) +中的数据路径: + + +``` +Train: + dataset: + name: SimpleDataSet + data_dir: ./train_data/seal_ppocr + label_file_list: [./train_data/seal_ppocr/seal_ppocr_img.txt] + ratio_list: [1.0] +``` +测试数据集配置修改后如下: +``` +Eval: + dataset: + name: SimpleDataSet + data_dir: ./train_data/seal_ppocr_test + label_file_list: [./train_data/seal_ppocr_test/seal_ppocr_img.txt] +``` + + +启动训练: +``` +!python3 tools/train.py -c configs/det/det_r50_db++_icdar15.yml -o Global.epoch_num=100 +``` + +考虑到数据较少,通过Global.epoch_num设置仅训练100个epoch。 +模型训练完成后,在测试集上预测的可视化效果如下: + +![](https://ai-studio-static-online.cdn.bcebos.com/498119182f0a414ab86ae2de752fa31c9ddc3a74a76847049cc57884602cb269) + + +如需获取已训练模型,请扫文末的二维码填写问卷,加入PaddleOCR官方交流群获取全部OCR垂类模型下载链接、《动手学OCR》电子书等全套OCR学习资料🎁 + + +### 5.2.2 印章文字识别 + +上一节中完成了印章文字的检测模型训练,本节介绍印章文字识别模型的训练。识别模型采用SVTR算法,SVTR算法是IJCAI收录的文字识别算法,SVTR模型具备超轻量高精度的特点。 + +在启动训练之前,需要准备印章文字识别需要的数据集,需要使用如下代码,将印章中的文字区域剪切出来构建训练集。 + +``` +import cv2 +import numpy as np + +def get_rotate_crop_image(img, points): + ''' + img_height, img_width = img.shape[0:2] + left = int(np.min(points[:, 0])) + right = int(np.max(points[:, 0])) + top = int(np.min(points[:, 1])) + bottom = int(np.max(points[:, 1])) + img_crop = img[top:bottom, left:right, :].copy() + points[:, 0] = points[:, 0] - left + points[:, 1] = points[:, 1] - top + ''' + assert len(points) == 4, "shape of points must be 4*2" + img_crop_width = int( + max( + np.linalg.norm(points[0] - 
points[1]), + np.linalg.norm(points[2] - points[3]))) + img_crop_height = int( + max( + np.linalg.norm(points[0] - points[3]), + np.linalg.norm(points[1] - points[2]))) + pts_std = np.float32([[0, 0], [img_crop_width, 0], + [img_crop_width, img_crop_height], + [0, img_crop_height]]) + M = cv2.getPerspectiveTransform(points, pts_std) + dst_img = cv2.warpPerspective( + img, + M, (img_crop_width, img_crop_height), + borderMode=cv2.BORDER_REPLICATE, + flags=cv2.INTER_CUBIC) + dst_img_height, dst_img_width = dst_img.shape[0:2] + if dst_img_height * 1.0 / dst_img_width >= 1.5: + dst_img = np.rot90(dst_img) + return dst_img + + +def run(data_dir, label_file, save_dir): + datas = open(label_file, 'r').readlines() + for idx, line in enumerate(datas): + img_path, label = line.strip().split('\t') + img_path = os.path.join(data_dir, img_path) + + label = json.loads(label) + src_im = cv2.imread(img_path) + if src_im is None: + continue + + for anno in label: + seal_box = anno['seal_box'] + txt_boxes = anno['polys'] + crop_im = get_rotate_crop_image(src_im, text_boxes) + + save_path = os.path.join(save_dir, f'{idx}.png') + if not os.path.exists(save_dir): + os.makedirs(save_dir) + # print(src_im.shape) + cv2.imwrite(save_path, crop_im) + +``` + + +数据处理完成后,即可配置训练的配置文件。SVTR配置文件选择[configs/rec/PP-OCRv3/ch_PP-OCRv3_rec.yml](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.6/configs/rec/PP-OCRv3/ch_PP-OCRv3_rec.yml) +修改SVTR配置文件中的训练数据部分如下: + +``` +Train: + dataset: + name: SimpleDataSet + data_dir: ./train_data/seal_ppocr_crop/ + label_file_list: + - ./train_data/seal_ppocr_crop/train_list.txt +``` + +修改预测部分配置文件: +``` +Train: + dataset: + name: SimpleDataSet + data_dir: ./train_data/seal_ppocr_crop/ + label_file_list: + - ./train_data/seal_ppocr_crop_test/train_list.txt +``` + +启动训练: + +``` +!python3 tools/train.py -c configs/rec/PP-OCRv3/ch_PP-OCRv3_rec.yml + +``` + +训练完成后可以发现测试集指标达到了61%。 +由于数据较少,训练时会发现在训练集上的acc指标远大于测试集上的acc指标,即出现过拟合现象。通过补充数据和一些数据增强可以缓解这个问题。 + + + +如需获取已训练模型,请扫下图二维码填写问卷,加入PaddleOCR官方交流群获取全部OCR垂类模型下载链接、《动手学OCR》电子书等全套OCR学习资料🎁 + + +![](https://ai-studio-static-online.cdn.bcebos.com/ea32877b717643289dc2121a2e573526d99d0f9eecc64ad4bd8dcf121cb5abde) diff --git "a/applications/\345\217\221\347\245\250\345\205\263\351\224\256\344\277\241\346\201\257\346\212\275\345\217\226.md" "b/applications/\345\217\221\347\245\250\345\205\263\351\224\256\344\277\241\346\201\257\346\212\275\345\217\226.md" index 14a6a1c8f1dd2350767afa162063b06791e79dd4..82f5b8d48600c6bebb4d3183ee801305d305d531 100644 --- "a/applications/\345\217\221\347\245\250\345\205\263\351\224\256\344\277\241\346\201\257\346\212\275\345\217\226.md" +++ "b/applications/\345\217\221\347\245\250\345\205\263\351\224\256\344\277\241\346\201\257\346\212\275\345\217\226.md" @@ -30,7 +30,7 @@ cd PaddleOCR # 安装PaddleOCR的依赖 pip install -r requirements.txt # 安装关键信息抽取任务的依赖 -pip install -r ./ppstructure/vqa/requirements.txt +pip install -r ./ppstructure/kie/requirements.txt ``` ## 4. 关键信息抽取 @@ -94,7 +94,7 @@ VI-LayoutXLM的配置为[ser_vi_layoutxlm_xfund_zh_udml.yml](../configs/kie/vi_l ```yml Architecture: - model_type: &model_type "vqa" + model_type: &model_type "kie" name: DistillationModel algorithm: Distillation Models: @@ -177,7 +177,7 @@ python3 tools/eval.py -c ./fapiao/ser_vi_layoutxlm.yml -o Architecture.Backbone. 
使用下面的命令进行预测。 ```bash -python3 tools/infer_vqa_token_ser.py -c fapiao/ser_vi_layoutxlm.yml -o Architecture.Backbone.checkpoints=fapiao/models/ser_vi_layoutxlm_fapiao_udml/best_accuracy Global.infer_img=./train_data/XFUND/zh_val/val.json Global.infer_mode=False +python3 tools/infer_kie_token_ser.py -c fapiao/ser_vi_layoutxlm.yml -o Architecture.Backbone.checkpoints=fapiao/models/ser_vi_layoutxlm_fapiao_udml/best_accuracy Global.infer_img=./train_data/XFUND/zh_val/val.json Global.infer_mode=False ``` 预测结果会保存在配置文件中的`Global.save_res_path`目录中。 @@ -195,7 +195,7 @@ python3 tools/infer_vqa_token_ser.py -c fapiao/ser_vi_layoutxlm.yml -o Architect ```bash -python3 tools/infer_vqa_token_ser.py -c fapiao/ser_vi_layoutxlm.yml -o Architecture.Backbone.checkpoints=fapiao/models/ser_vi_layoutxlm_fapiao_udml/best_accuracy Global.infer_img=./train_data/zzsfp/imgs/b25.jpg Global.infer_mode=True +python3 tools/infer_kie_token_ser.py -c fapiao/ser_vi_layoutxlm.yml -o Architecture.Backbone.checkpoints=fapiao/models/ser_vi_layoutxlm_fapiao_udml/best_accuracy Global.infer_img=./train_data/zzsfp/imgs/b25.jpg Global.infer_mode=True ``` 结果如下所示。 @@ -211,7 +211,7 @@ python3 tools/infer_vqa_token_ser.py -c fapiao/ser_vi_layoutxlm.yml -o Architect 如果希望构建基于你在垂类场景训练得到的OCR检测与识别模型,可以使用下面的方法传入检测与识别的inference 模型路径,即可完成OCR文本检测与识别以及SER的串联过程。 ```bash -python3 tools/infer_vqa_token_ser.py -c fapiao/ser_vi_layoutxlm.yml -o Architecture.Backbone.checkpoints=fapiao/models/ser_vi_layoutxlm_fapiao_udml/best_accuracy Global.infer_img=./train_data/zzsfp/imgs/b25.jpg Global.infer_mode=True Global.kie_rec_model_dir="your_rec_model" Global.kie_det_model_dir="your_det_model" +python3 tools/infer_kie_token_ser.py -c fapiao/ser_vi_layoutxlm.yml -o Architecture.Backbone.checkpoints=fapiao/models/ser_vi_layoutxlm_fapiao_udml/best_accuracy Global.infer_img=./train_data/zzsfp/imgs/b25.jpg Global.infer_mode=True Global.kie_rec_model_dir="your_rec_model" Global.kie_det_model_dir="your_det_model" ``` ### 4.4 关系抽取(Relation Extraction) @@ -316,7 +316,7 @@ python3 tools/eval.py -c ./fapiao/re_vi_layoutxlm.yml -o Architecture.Backbone.c # -o 后面的字段是RE任务的配置 # -c_ser 后面的是SER任务的配置文件 # -c_ser 后面的字段是SER任务的配置 -python3 tools/infer_vqa_token_ser_re.py -c fapiao/re_vi_layoutxlm.yml -o Architecture.Backbone.checkpoints=fapiao/models/re_vi_layoutxlm_fapiao_udml/best_accuracy Global.infer_img=./train_data/zzsfp/val.json Global.infer_mode=False -c_ser fapiao/ser_vi_layoutxlm.yml -o_ser Architecture.Backbone.checkpoints=fapiao/models/ser_vi_layoutxlm_fapiao_udml/best_accuracy +python3 tools/infer_kie_token_ser_re.py -c fapiao/re_vi_layoutxlm.yml -o Architecture.Backbone.checkpoints=fapiao/models/re_vi_layoutxlm_fapiao_trained/best_accuracy Global.infer_img=./train_data/zzsfp/val.json Global.infer_mode=False -c_ser fapiao/ser_vi_layoutxlm.yml -o_ser Architecture.Backbone.checkpoints=fapiao/models/ser_vi_layoutxlm_fapiao_trained/best_accuracy ``` 预测结果会保存在配置文件中的`Global.save_res_path`目录中。 @@ -333,11 +333,11 @@ python3 tools/infer_vqa_token_ser_re.py -c fapiao/re_vi_layoutxlm.yml -o Archite 如果希望使用OCR引擎结果得到的结果进行推理,则可以使用下面的命令进行推理。 ```bash -python3 tools/infer_vqa_token_ser_re.py -c fapiao/re_vi_layoutxlm.yml -o Architecture.Backbone.checkpoints=fapiao/models/re_vi_layoutxlm_fapiao_udml/best_accuracy Global.infer_img=./train_data/zzsfp/val.json Global.infer_mode=True -c_ser fapiao/ser_vi_layoutxlm.yml -o_ser Architecture.Backbone.checkpoints=fapiao/models/ser_vi_layoutxlm_fapiao_udml/best_accuracy +python3 tools/infer_kie_token_ser_re.py -c fapiao/re_vi_layoutxlm.yml -o 
Architecture.Backbone.checkpoints=fapiao/models/re_vi_layoutxlm_fapiao_udml/best_accuracy Global.infer_img=./train_data/zzsfp/val.json Global.infer_mode=True -c_ser fapiao/ser_vi_layoutxlm.yml -o_ser Architecture.Backbone.checkpoints=fapiao/models/ser_vi_layoutxlm_fapiao_udml/best_accuracy ``` 如果希望构建基于你在垂类场景训练得到的OCR检测与识别模型,可以使用下面的方法传入,即可完成SER + RE的串联过程。 ```bash -python3 tools/infer_vqa_token_ser_re.py -c fapiao/re_vi_layoutxlm.yml -o Architecture.Backbone.checkpoints=fapiao/models/re_vi_layoutxlm_fapiao_udml/best_accuracy Global.infer_img=./train_data/zzsfp/val.json Global.infer_mode=True -c_ser fapiao/ser_vi_layoutxlm.yml -o_ser Architecture.Backbone.checkpoints=fapiao/models/ser_vi_layoutxlm_fapiao_udml/best_accuracy Global.kie_rec_model_dir="your_rec_model" Global.kie_det_model_dir="your_det_model" +python3 tools/infer_kie_token_ser_re.py -c fapiao/re_vi_layoutxlm.yml -o Architecture.Backbone.checkpoints=fapiao/models/re_vi_layoutxlm_fapiao_udml/best_accuracy Global.infer_img=./train_data/zzsfp/val.json Global.infer_mode=True -c_ser fapiao/ser_vi_layoutxlm.yml -o_ser Architecture.Backbone.checkpoints=fapiao/models/ser_vi_layoutxlm_fapiao_udml/best_accuracy Global.kie_rec_model_dir="your_rec_model" Global.kie_det_model_dir="your_det_model" ``` diff --git "a/applications/\345\277\253\351\200\237\346\236\204\345\273\272\345\215\241\350\257\201\347\261\273OCR.md" "b/applications/\345\277\253\351\200\237\346\236\204\345\273\272\345\215\241\350\257\201\347\261\273OCR.md" new file mode 100644 index 0000000000000000000000000000000000000000..ab9ddf1bbd20e538f0b1e3d43e9a5af5383487b4 --- /dev/null +++ "b/applications/\345\277\253\351\200\237\346\236\204\345\273\272\345\215\241\350\257\201\347\261\273OCR.md" @@ -0,0 +1,782 @@ +# 快速构建卡证类OCR + + +- [快速构建卡证类OCR](#快速构建卡证类ocr) + - [1. 金融行业卡证识别应用](#1-金融行业卡证识别应用) + - [1.1 金融行业中的OCR相关技术](#11-金融行业中的ocr相关技术) + - [1.2 金融行业中的卡证识别场景介绍](#12-金融行业中的卡证识别场景介绍) + - [1.3 OCR落地挑战](#13-ocr落地挑战) + - [2. 卡证识别技术解析](#2-卡证识别技术解析) + - [2.1 卡证分类模型](#21-卡证分类模型) + - [2.2 卡证识别模型](#22-卡证识别模型) + - [3. OCR技术拆解](#3-ocr技术拆解) + - [3.1技术流程](#31技术流程) + - [3.2 OCR技术拆解---卡证分类](#32-ocr技术拆解---卡证分类) + - [卡证分类:数据、模型准备](#卡证分类数据模型准备) + - [卡证分类---修改配置文件](#卡证分类---修改配置文件) + - [卡证分类---训练](#卡证分类---训练) + - [3.2 OCR技术拆解---卡证识别](#32-ocr技术拆解---卡证识别) + - [身份证识别:检测+分类](#身份证识别检测分类) + - [数据标注](#数据标注) + - [4 . 项目实践](#4--项目实践) + - [4.1 环境准备](#41-环境准备) + - [4.2 配置文件修改](#42-配置文件修改) + - [4.3 代码修改](#43-代码修改) + - [4.3.1 数据读取](#431-数据读取) + - [4.3.2 head修改](#432--head修改) + - [4.3.3 修改loss](#433-修改loss) + - [4.3.4 后处理](#434-后处理) + - [4.4. 模型启动](#44-模型启动) + - [5 总结](#5-总结) + - [References](#references) + +## 1. 金融行业卡证识别应用 + +### 1.1 金融行业中的OCR相关技术 + +* 《“十四五”数字经济发展规划》指出,2020年我国数字经济核心产业增加值占GDP比重达7.8%,随着数字经济迈向全面扩展,到2025年该比例将提升至10%。 + +* 在过去数年的跨越发展与积累沉淀中,数字金融、金融科技已在对金融业的重塑与再造中充分印证了其自身价值。 + +* 以智能为目标,提升金融数字化水平,实现业务流程自动化,降低人力成本。 + + +![](https://ai-studio-static-online.cdn.bcebos.com/8bb381f164c54ea9b4043cf66fc92ffdea8aaf851bab484fa6e19bd2f93f154f) + + + +### 1.2 金融行业中的卡证识别场景介绍 + +应用场景:身份证、银行卡、营业执照、驾驶证等。 + +应用难点:由于数据的采集来源多样,以及实际采集数据各种噪声:反光、褶皱、模糊、倾斜等各种问题干扰。 + +![](https://ai-studio-static-online.cdn.bcebos.com/981640e17d05487e961162f8576c9e11634ca157f79048d4bd9d3bc21722afe8) + + + +### 1.3 OCR落地挑战 + + +![](https://ai-studio-static-online.cdn.bcebos.com/a5973a8ddeff4bd7ac082f02dc4d0c79de21e721b41641cbb831f23c2cb8fce2) + + + + + +## 2. 
卡证识别技术解析 + + +![](https://ai-studio-static-online.cdn.bcebos.com/d7f96effc2434a3ca2d4144ff33c50282b830670c892487d8d7dec151921cce7) + + +### 2.1 卡证分类模型 + +卡证分类:基于PPLCNet + +与其他轻量级模型相比在CPU环境下ImageNet数据集上的表现 + +![](https://ai-studio-static-online.cdn.bcebos.com/cbda3390cb994f98a3c8a9ba88c90c348497763f6c9f4b4797f7d63d84da5f63) + +![](https://ai-studio-static-online.cdn.bcebos.com/dedab7b7fd6543aa9e7f625132b24e3ba3f200e361fa468dac615f7814dfb98d) + + + +* 模型来自模型库PaddleClas,它是一个图像识别和图像分类任务的工具集,助力使用者训练出更好的视觉模型和应用落地。 + + +![](https://ai-studio-static-online.cdn.bcebos.com/606d1afaf0d0484a99b1d39895d394b22f24e74591514796859a9ea3a2799b78) + + + +### 2.2 卡证识别模型 + +* 检测:DBNet 识别:SVRT + +![](https://ai-studio-static-online.cdn.bcebos.com/9a7a4e19edc24310b46620f2ee7430f918223b93d4f14a15a52973c096926bad) + + +* PPOCRv3在文本检测、识别进行了一系列改进优化,在保证精度的同时提升预测效率 + + +![](https://ai-studio-static-online.cdn.bcebos.com/6afdbb77e8db4aef9b169e4e94c5d90a9764cfab4f2c4c04aa9afdf4f54d7680) + + +![](https://ai-studio-static-online.cdn.bcebos.com/c1a7d197847a4f168848c59b8e625d1d5e8066b778144395a8b9382bb85dc364) + + +## 3. OCR技术拆解 + +### 3.1技术流程 + +![](https://ai-studio-static-online.cdn.bcebos.com/89ba046177864d8783ced6cb31ba92a66ca2169856a44ee59ac2bb18e44a6c4b) + + +### 3.2 OCR技术拆解---卡证分类 + +#### 卡证分类:数据、模型准备 + + +A 使用爬虫获取无标注数据,将相同类别的放在同一文件夹下,文件名从0开始命名。具体格式如下图所示。 + +​ 注:卡证类数据,建议每个类别数据量在500张以上 +![](https://ai-studio-static-online.cdn.bcebos.com/6f875b6e695e4fe5aedf427beb0d4ce8064ad7cc33c44faaad59d3eb9732639d) + + +B 一行命令生成标签文件 + +``` +tree -r -i -f | grep -E "jpg|JPG|jpeg|JPEG|png|PNG|webp" | awk -F "/" '{print $0" "$2}' > train_list.txt +``` + +C [下载预训练模型 ](https://github.com/PaddlePaddle/PaddleClas/blob/release/2.4/docs/zh_CN/models/PP-LCNet.md) + + + +#### 卡证分类---修改配置文件 + + +配置文件主要修改三个部分: + + 全局参数:预训练模型路径/训练轮次/图像尺寸 + + 模型结构:分类数 + + 数据处理:训练/评估数据路径 + + + ![](https://ai-studio-static-online.cdn.bcebos.com/e0dc05039c7444c5ab1260ff550a408748df8d4cfe864223adf390e51058dbd5) + +#### 卡证分类---训练 + + +指定配置文件启动训练: + +``` +!python /home/aistudio/work/PaddleClas/tools/train.py -c /home/aistudio/work/PaddleClas/ppcls/configs/PULC/text_image_orientation/PPLCNet_x1_0.yaml +``` +![](https://ai-studio-static-online.cdn.bcebos.com/06af09bde845449ba0a676410f4daa1cdc3983ac95034bdbbafac3b7fd94042f) + +​ 注:日志中显示了训练结果和评估结果(训练时可以设置固定轮数评估一次) + + +### 3.2 OCR技术拆解---卡证识别 + +卡证识别(以身份证检测为例) +存在的困难及问题: + + * 在自然场景下,由于各种拍摄设备以及光线、角度不同等影响导致实际得到的证件影像千差万别。 + + * 如何快速提取需要的关键信息 + + * 多行的文本信息,检测结果如何正确拼接 + + ![](https://ai-studio-static-online.cdn.bcebos.com/4f8f5533a2914e0a821f4a639677843c32ec1f08a1b1488d94c0b8bfb6e72d2d) + + + +* OCR技术拆解---OCR工具库 + + PaddleOCR是一个丰富、领先且实用的OCR工具库,助力开发者训练出更好的模型并应用落地 + +![](https://ai-studio-static-online.cdn.bcebos.com/16c5e16d53b8428c95129cac4f5520204d869910247943e494d854227632e882) + + +身份证识别:用现有的方法识别 + +![](https://ai-studio-static-online.cdn.bcebos.com/12d402e6a06d482a88f979e0ebdfb39f4d3fc8b80517499689ec607ddb04fbf3) + + + + +#### 身份证识别:检测+分类 + +> 方法:基于现有的dbnet检测模型,加入分类方法。检测同时进行分类,从一定程度上优化识别流程 + +![](https://ai-studio-static-online.cdn.bcebos.com/e1e798c87472477fa0bfca0da12bb0c180845a3e167a4761b0d26ff4330a5ccb) + + +![](https://ai-studio-static-online.cdn.bcebos.com/23a5a19c746441309864586e467f995ec8a551a3661640e493fc4d77520309cd) + +#### 数据标注 + +使用PaddleOCRLable进行快速标注 + +![](https://ai-studio-static-online.cdn.bcebos.com/a73180425fa14f919ce52d9bf70246c3995acea1831843cca6c17d871b8f5d95) + + +* 修改PPOCRLabel.py,将下图中的kie参数设置为True + + 
+![](https://ai-studio-static-online.cdn.bcebos.com/d445cf4d850e4063b9a7fc6a075c12204cf912ff23ec471fa2e268b661b3d693) + + +* 数据标注踩坑分享 + +![](https://ai-studio-static-online.cdn.bcebos.com/89f42eccd600439fa9e28c97ccb663726e4e54ce3a854825b4c3b7d554ea21df) + +​ 注:两者只有标注有差别,训练参数数据集都相同 + +## 4 . 项目实践 + +AIStudio项目链接:[快速构建卡证类OCR](https://aistudio.baidu.com/aistudio/projectdetail/4459116) + +### 4.1 环境准备 + +1)拉取[paddleocr](https://github.com/PaddlePaddle/PaddleOCR)项目,如果从github上拉取速度慢可以选择从gitee上获取。 +``` +!git clone https://github.com/PaddlePaddle/PaddleOCR.git -b release/2.6 /home/aistudio/work/ +``` + +2)获取并解压预训练模型,如果要使用其他模型可以从模型库里自主选择合适模型。 +``` +!wget -P work/pre_trained/ https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_distill_train.tar +!tar -vxf /home/aistudio/work/pre_trained/ch_PP-OCRv3_det_distill_train.tar -C /home/aistudio/work/pre_trained +``` +3) 安装必要依赖 +``` +!pip install -r /home/aistudio/work/requirements.txt +``` + +### 4.2 配置文件修改 + +修改配置文件 *work/configs/det/detmv3db.yml* + +具体修改说明如下: + +![](https://ai-studio-static-online.cdn.bcebos.com/fcdf517af5a6466294d72db7450209378d8efd9b77764e329d3f2aff3579a20c) + + 注:在上述的配置文件的Global变量中需要添加以下两个参数: + +​ label_list 为标签表 +​ num_classes 为分类数 +​ 上述两个参数根据实际的情况配置即可 + + +![](https://ai-studio-static-online.cdn.bcebos.com/0b056be24f374812b61abf43305774767ae122c8479242f98aa0799b7bfc81d4) + +其中lable_list内容如下例所示,***建议第一个参数设置为 background,不要设置为实际要提取的关键信息种类***: + +![](https://ai-studio-static-online.cdn.bcebos.com/9fc78bbcdf754898b9b2c7f000ddf562afac786482ab4f2ab063e2242faa542a) + +配置文件中的其他设置说明 + +![](https://ai-studio-static-online.cdn.bcebos.com/c7fc5e631dd44bc8b714630f4e49d9155a831d9e56c64e2482ded87081d0db22) + +![](https://ai-studio-static-online.cdn.bcebos.com/8d1022ac25d9474daa4fb236235bd58760039d58ad46414f841559d68e0d057f) + +![](https://ai-studio-static-online.cdn.bcebos.com/ee927ad9ebd442bb96f163a7ebbf4bc95e6bedee97324a51887cf82de0851fd3) + + + + +### 4.3 代码修改 + + +#### 4.3.1 数据读取 + + + +* 修改 PaddleOCR/ppocr/data/imaug/label_ops.py中的DetLabelEncode + + +```python +class DetLabelEncode(object): + + # 修改检测标签的编码处,新增了参数分类数:num_classes,重写初始化方法,以及分类标签的读取 + + def __init__(self, label_list, num_classes=8, **kwargs): + self.num_classes = num_classes + self.label_list = [] + if label_list: + if isinstance(label_list, str): + with open(label_list, 'r+', encoding='utf-8') as f: + for line in f.readlines(): + self.label_list.append(line.replace("\n", "")) + else: + self.label_list = label_list + else: + assert ' please check label_list whether it is none or config is right' + + if num_classes != len(self.label_list): # 校验分类数和标签的一致性 + assert 'label_list length is not equal to the num_classes' + + def __call__(self, data): + label = data['label'] + label = json.loads(label) + nBox = len(label) + boxes, txts, txt_tags, classes = [], [], [], [] + for bno in range(0, nBox): + box = label[bno]['points'] + txt = label[bno]['key_cls'] # 此处将kie中的参数作为分类读取 + boxes.append(box) + txts.append(txt) + + if txt in ['*', '###']: + txt_tags.append(True) + if self.num_classes > 1: + classes.append(-2) + else: + txt_tags.append(False) + if self.num_classes > 1: # 将KIE内容的key标签作为分类标签使用 + classes.append(int(self.label_list.index(txt))) + + if len(boxes) == 0: + + return None + boxes = self.expand_points_num(boxes) + boxes = np.array(boxes, dtype=np.float32) + txt_tags = np.array(txt_tags, dtype=np.bool) + classes = classes + data['polys'] = boxes + data['texts'] = txts + data['ignore_tags'] = txt_tags + if self.num_classes > 1: + data['classes'] = classes + return data 
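+    # Illustrative only (not from the original tutorial): one line of the PPOCRLabel
+    # Label.txt file that this encoder parses, i.e. an image path, a tab, then a JSON
+    # list whose items provide the "points" polygon and the "key_cls" class field:
+    #   imgs/0001.jpg\t[{"transcription": "姓名", "points": [[10,20],[120,20],[120,60],[10,60]], "difficult": false, "key_cls": "name"}]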
+``` + +* 修改 PaddleOCR/ppocr/data/imaug/make_shrink_map.py中的MakeShrinkMap类。这里需要注意的是,如果我们设置的label_list中的第一个参数为要检测的信息那么会得到如下的mask, + +举例说明: +这是检测的mask图,图中有四个mask那么实际对应的分类应该是4类 + +![](https://ai-studio-static-online.cdn.bcebos.com/42d2188d3d6b498880952e12c3ceae1efabf135f8d9f4c31823f09ebe02ba9d2) + + + +label_list中第一个为关键分类,则得到的分类Mask实际如下,与上图相比,少了一个box: + +![](https://ai-studio-static-online.cdn.bcebos.com/864604967256461aa7c5d32cd240645e9f4c70af773341d5911f22d5a3e87b5f) + + + +```python +class MakeShrinkMap(object): + r''' + Making binary mask from detection data with ICDAR format. + Typically following the process of class `MakeICDARData`. + ''' + + def __init__(self, min_text_size=8, shrink_ratio=0.4, num_classes=8, **kwargs): + self.min_text_size = min_text_size + self.shrink_ratio = shrink_ratio + self.num_classes = num_classes # 添加了分类 + + def __call__(self, data): + image = data['image'] + text_polys = data['polys'] + ignore_tags = data['ignore_tags'] + if self.num_classes > 1: + classes = data['classes'] + + h, w = image.shape[:2] + text_polys, ignore_tags = self.validate_polygons(text_polys, + ignore_tags, h, w) + gt = np.zeros((h, w), dtype=np.float32) + mask = np.ones((h, w), dtype=np.float32) + gt_class = np.zeros((h, w), dtype=np.float32) # 新增分类 + for i in range(len(text_polys)): + polygon = text_polys[i] + height = max(polygon[:, 1]) - min(polygon[:, 1]) + width = max(polygon[:, 0]) - min(polygon[:, 0]) + if ignore_tags[i] or min(height, width) < self.min_text_size: + cv2.fillPoly(mask, + polygon.astype(np.int32)[np.newaxis, :, :], 0) + ignore_tags[i] = True + else: + polygon_shape = Polygon(polygon) + subject = [tuple(l) for l in polygon] + padding = pyclipper.PyclipperOffset() + padding.AddPath(subject, pyclipper.JT_ROUND, + pyclipper.ET_CLOSEDPOLYGON) + shrinked = [] + + # Increase the shrink ratio every time we get multiple polygon returned back + possible_ratios = np.arange(self.shrink_ratio, 1, + self.shrink_ratio) + np.append(possible_ratios, 1) + for ratio in possible_ratios: + distance = polygon_shape.area * ( + 1 - np.power(ratio, 2)) / polygon_shape.length + shrinked = padding.Execute(-distance) + if len(shrinked) == 1: + break + + if shrinked == []: + cv2.fillPoly(mask, + polygon.astype(np.int32)[np.newaxis, :, :], 0) + ignore_tags[i] = True + continue + + for each_shirnk in shrinked: + shirnk = np.array(each_shirnk).reshape(-1, 2) + cv2.fillPoly(gt, [shirnk.astype(np.int32)], 1) + if self.num_classes > 1: # 绘制分类的mask + cv2.fillPoly(gt_class, polygon.astype(np.int32)[np.newaxis, :, :], classes[i]) + + + data['shrink_map'] = gt + + if self.num_classes > 1: + data['class_mask'] = gt_class + + data['shrink_mask'] = mask + return data +``` + +由于在训练数据中会对数据进行resize设置,yml中的操作为:EastRandomCropData,所以需要修改PaddleOCR/ppocr/data/imaug/random_crop_data.py中的EastRandomCropData + + +```python +class EastRandomCropData(object): + def __init__(self, + size=(640, 640), + max_tries=10, + min_crop_side_ratio=0.1, + keep_ratio=True, + num_classes=8, + **kwargs): + self.size = size + self.max_tries = max_tries + self.min_crop_side_ratio = min_crop_side_ratio + self.keep_ratio = keep_ratio + self.num_classes = num_classes + + def __call__(self, data): + img = data['image'] + text_polys = data['polys'] + ignore_tags = data['ignore_tags'] + texts = data['texts'] + if self.num_classes > 1: + classes = data['classes'] + all_care_polys = [ + text_polys[i] for i, tag in enumerate(ignore_tags) if not tag + ] + # 计算crop区域 + crop_x, crop_y, crop_w, crop_h = crop_area( + img, all_care_polys, 
self.min_crop_side_ratio, self.max_tries) + # crop 图片 保持比例填充 + scale_w = self.size[0] / crop_w + scale_h = self.size[1] / crop_h + scale = min(scale_w, scale_h) + h = int(crop_h * scale) + w = int(crop_w * scale) + if self.keep_ratio: + padimg = np.zeros((self.size[1], self.size[0], img.shape[2]), + img.dtype) + padimg[:h, :w] = cv2.resize( + img[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w], (w, h)) + img = padimg + else: + img = cv2.resize( + img[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w], + tuple(self.size)) + # crop 文本框 + text_polys_crop = [] + ignore_tags_crop = [] + texts_crop = [] + classes_crop = [] + for poly, text, tag,class_index in zip(text_polys, texts, ignore_tags,classes): + poly = ((poly - (crop_x, crop_y)) * scale).tolist() + if not is_poly_outside_rect(poly, 0, 0, w, h): + text_polys_crop.append(poly) + ignore_tags_crop.append(tag) + texts_crop.append(text) + if self.num_classes > 1: + classes_crop.append(class_index) + data['image'] = img + data['polys'] = np.array(text_polys_crop) + data['ignore_tags'] = ignore_tags_crop + data['texts'] = texts_crop + if self.num_classes > 1: + data['classes'] = classes_crop + return data +``` + +#### 4.3.2 head修改 + + + +主要修改 ppocr/modeling/heads/det_db_head.py,将Head类中的最后一层的输出修改为实际的分类数,同时在DBHead中新增分类的head。 + +![](https://ai-studio-static-online.cdn.bcebos.com/0e25da2ccded4af19e95c85c3d3287ab4d53e31a4eed4607b6a4cb637c43f6d3) + + + +#### 4.3.3 修改loss + + +修改PaddleOCR/ppocr/losses/det_db_loss.py中的DBLoss类,分类采用交叉熵损失函数进行计算。 + +![](https://ai-studio-static-online.cdn.bcebos.com/dc10a070018d4d27946c26ec24a2a85bc3f16422f4964f72a9b63c6170d954e1) + + +#### 4.3.4 后处理 + + + +由于涉及到eval以及后续推理能否正常使用,我们需要修改后处理的相关代码,修改位置 PaddleOCR/ppocr/postprocess/db_postprocess.py中的DBPostProcess类 + + +```python +class DBPostProcess(object): + """ + The post process for Differentiable Binarization (DB). 
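+
+    When a per-pixel class map is supplied, each returned box additionally carries a
+    class index and a class score, obtained by a majority vote over the pixels inside
+    the box (see box_score_fast / box_score_slow below).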
+ """ + + def __init__(self, + thresh=0.3, + box_thresh=0.7, + max_candidates=1000, + unclip_ratio=2.0, + use_dilation=False, + score_mode="fast", + **kwargs): + self.thresh = thresh + self.box_thresh = box_thresh + self.max_candidates = max_candidates + self.unclip_ratio = unclip_ratio + self.min_size = 3 + self.score_mode = score_mode + assert score_mode in [ + "slow", "fast" + ], "Score mode must be in [slow, fast] but got: {}".format(score_mode) + + self.dilation_kernel = None if not use_dilation else np.array( + [[1, 1], [1, 1]]) + + def boxes_from_bitmap(self, pred, _bitmap, classes, dest_width, dest_height): + """ + _bitmap: single map with shape (1, H, W), + whose values are binarized as {0, 1} + """ + + bitmap = _bitmap + height, width = bitmap.shape + + outs = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST, + cv2.CHAIN_APPROX_SIMPLE) + if len(outs) == 3: + img, contours, _ = outs[0], outs[1], outs[2] + elif len(outs) == 2: + contours, _ = outs[0], outs[1] + + num_contours = min(len(contours), self.max_candidates) + + boxes = [] + scores = [] + class_indexes = [] + class_scores = [] + for index in range(num_contours): + contour = contours[index] + points, sside = self.get_mini_boxes(contour) + if sside < self.min_size: + continue + points = np.array(points) + if self.score_mode == "fast": + score, class_index, class_score = self.box_score_fast(pred, points.reshape(-1, 2), classes) + else: + score, class_index, class_score = self.box_score_slow(pred, contour, classes) + if self.box_thresh > score: + continue + + box = self.unclip(points).reshape(-1, 1, 2) + box, sside = self.get_mini_boxes(box) + if sside < self.min_size + 2: + continue + box = np.array(box) + + box[:, 0] = np.clip( + np.round(box[:, 0] / width * dest_width), 0, dest_width) + box[:, 1] = np.clip( + np.round(box[:, 1] / height * dest_height), 0, dest_height) + + boxes.append(box.astype(np.int16)) + scores.append(score) + + class_indexes.append(class_index) + class_scores.append(class_score) + + if classes is None: + return np.array(boxes, dtype=np.int16), scores + else: + return np.array(boxes, dtype=np.int16), scores, class_indexes, class_scores + + def unclip(self, box): + unclip_ratio = self.unclip_ratio + poly = Polygon(box) + distance = poly.area * unclip_ratio / poly.length + offset = pyclipper.PyclipperOffset() + offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON) + expanded = np.array(offset.Execute(distance)) + return expanded + + def get_mini_boxes(self, contour): + bounding_box = cv2.minAreaRect(contour) + points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0]) + + index_1, index_2, index_3, index_4 = 0, 1, 2, 3 + if points[1][1] > points[0][1]: + index_1 = 0 + index_4 = 1 + else: + index_1 = 1 + index_4 = 0 + if points[3][1] > points[2][1]: + index_2 = 2 + index_3 = 3 + else: + index_2 = 3 + index_3 = 2 + + box = [ + points[index_1], points[index_2], points[index_3], points[index_4] + ] + return box, min(bounding_box[1]) + + def box_score_fast(self, bitmap, _box, classes): + ''' + box_score_fast: use bbox mean score as the mean score + ''' + h, w = bitmap.shape[:2] + box = _box.copy() + xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int), 0, w - 1) + xmax = np.clip(np.ceil(box[:, 0].max()).astype(np.int), 0, w - 1) + ymin = np.clip(np.floor(box[:, 1].min()).astype(np.int), 0, h - 1) + ymax = np.clip(np.ceil(box[:, 1].max()).astype(np.int), 0, h - 1) + + mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8) + box[:, 0] = box[:, 0] - xmin + 
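+        # shift the box into the local coordinates of the cropped mask region
+        # (x offset above, y offset below) before rasterizing it with cv2.fillPoly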
box[:, 1] = box[:, 1] - ymin + cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1) + + if classes is None: + return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0], None, None + else: + k = 999 + class_mask = np.full((ymax - ymin + 1, xmax - xmin + 1), k, dtype=np.int32) + + cv2.fillPoly(class_mask, box.reshape(1, -1, 2).astype(np.int32), 0) + classes = classes[ymin:ymax + 1, xmin:xmax + 1] + + new_classes = classes + class_mask + a = new_classes.reshape(-1) + b = np.where(a >= k) + classes = np.delete(a, b[0].tolist()) + + class_index = np.argmax(np.bincount(classes)) + class_score = np.sum(classes == class_index) / len(classes) + + return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0], class_index, class_score + + def box_score_slow(self, bitmap, contour, classes): + """ + box_score_slow: use polyon mean score as the mean score + """ + h, w = bitmap.shape[:2] + contour = contour.copy() + contour = np.reshape(contour, (-1, 2)) + + xmin = np.clip(np.min(contour[:, 0]), 0, w - 1) + xmax = np.clip(np.max(contour[:, 0]), 0, w - 1) + ymin = np.clip(np.min(contour[:, 1]), 0, h - 1) + ymax = np.clip(np.max(contour[:, 1]), 0, h - 1) + + mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8) + + contour[:, 0] = contour[:, 0] - xmin + contour[:, 1] = contour[:, 1] - ymin + + cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype(np.int32), 1) + + if classes is None: + return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0], None, None + else: + k = 999 + class_mask = np.full((ymax - ymin + 1, xmax - xmin + 1), k, dtype=np.int32) + + cv2.fillPoly(class_mask, contour.reshape(1, -1, 2).astype(np.int32), 0) + classes = classes[ymin:ymax + 1, xmin:xmax + 1] + + new_classes = classes + class_mask + a = new_classes.reshape(-1) + b = np.where(a >= k) + classes = np.delete(a, b[0].tolist()) + + class_index = np.argmax(np.bincount(classes)) + class_score = np.sum(classes == class_index) / len(classes) + + return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0], class_index, class_score + + def __call__(self, outs_dict, shape_list): + pred = outs_dict['maps'] + if isinstance(pred, paddle.Tensor): + pred = pred.numpy() + pred = pred[:, 0, :, :] + segmentation = pred > self.thresh + + if "classes" in outs_dict: + classes = outs_dict['classes'] + if isinstance(classes, paddle.Tensor): + classes = classes.numpy() + classes = classes[:, 0, :, :] + + else: + classes = None + + boxes_batch = [] + for batch_index in range(pred.shape[0]): + src_h, src_w, ratio_h, ratio_w = shape_list[batch_index] + if self.dilation_kernel is not None: + mask = cv2.dilate( + np.array(segmentation[batch_index]).astype(np.uint8), + self.dilation_kernel) + else: + mask = segmentation[batch_index] + + if classes is None: + boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask, None, + src_w, src_h) + boxes_batch.append({'points': boxes}) + else: + boxes, scores, class_indexes, class_scores = self.boxes_from_bitmap(pred[batch_index], mask, + classes[batch_index], + src_w, src_h) + boxes_batch.append({'points': boxes, "classes": class_indexes, "class_scores": class_scores}) + + return boxes_batch +``` + +### 4.4. 
模型启动 + +在完成上述步骤后我们就可以正常启动训练 + +``` +!python /home/aistudio/work/PaddleOCR/tools/train.py -c /home/aistudio/work/PaddleOCR/configs/det/det_mv3_db.yml +``` + +其他命令: +``` +!python /home/aistudio/work/PaddleOCR/tools/eval.py -c /home/aistudio/work/PaddleOCR/configs/det/det_mv3_db.yml +!python /home/aistudio/work/PaddleOCR/tools/infer_det.py -c /home/aistudio/work/PaddleOCR/configs/det/det_mv3_db.yml +``` +模型推理 +``` +!python /home/aistudio/work/PaddleOCR/tools/infer/predict_det.py --image_dir="/home/aistudio/work/test_img/" --det_model_dir="/home/aistudio/work/PaddleOCR/output/infer" +``` + +## 5 总结 + +1. 分类+检测在一定程度上能够缩短用时,具体的模型选取要根据业务场景恰当选择。 +2. 数据标注需要多次进行测试调整标注方法,一般进行检测模型微调,需要标注至少上百张。 +3. 设置合理的batch_size以及resize大小,同时注意lr设置。 + + +## References + +1 https://github.com/PaddlePaddle/PaddleOCR + +2 https://github.com/PaddlePaddle/PaddleClas + +3 https://blog.csdn.net/YY007H/article/details/124491217 diff --git "a/applications/\346\211\253\346\217\217\345\220\210\345\220\214\345\205\263\351\224\256\344\277\241\346\201\257\346\217\220\345\217\226.md" "b/applications/\346\211\253\346\217\217\345\220\210\345\220\214\345\205\263\351\224\256\344\277\241\346\201\257\346\217\220\345\217\226.md" new file mode 100644 index 0000000000000000000000000000000000000000..26c64a34c14e621b4bc15df67c66141024bf7cc4 --- /dev/null +++ "b/applications/\346\211\253\346\217\217\345\220\210\345\220\214\345\205\263\351\224\256\344\277\241\346\201\257\346\217\220\345\217\226.md" @@ -0,0 +1,284 @@ +# 金融智能核验:扫描合同关键信息抽取 + +本案例将使用OCR技术和通用信息抽取技术,实现合同关键信息审核和比对。通过本章的学习,你可以快速掌握: + +1. 使用PaddleOCR提取扫描文本内容 +2. 使用PaddleNLP抽取自定义信息 + +点击进入 [AI Studio 项目](https://aistudio.baidu.com/aistudio/projectdetail/4545772) + +## 1. 项目背景 +合同审核广泛应用于大中型企业、上市公司、证券、基金公司中,是规避风险的重要任务。 +- 合同内容对比:合同审核场景中,快速找出不同版本合同修改区域、版本差异;如合同盖章归档场景中有效识别实际签署的纸质合同、电子版合同差异。 + +- 合规性检查:法务人员进行合同审核,如合同完备性检查、大小写金额检查、签约主体一致性检查、双方权利和义务对等性分析等。 + +- 风险点识别:通过合同审核可识别事实倾向型风险点和数值计算型风险点等,例如交付地点约定不明、合同总价款不一致、重要条款缺失等风险点。 + + +![](https://ai-studio-static-online.cdn.bcebos.com/d5143df967fa4364a38868793fe7c57b0c0b1213930243babd6ae01423dcbc4d) + +传统业务中大多使用人工进行纸质版合同审核,存在成本高,工作量大,效率低的问题,且一旦出错将造成巨额损失。 + + +本项目针对以上场景,使用PaddleOCR+PaddleNLP快速提取文本内容,经过少量数据微调即可准确抽取关键信息,**高效完成合同内容对比、合规性检查、风险点识别等任务,提高效率,降低风险**。 + +![](https://ai-studio-static-online.cdn.bcebos.com/54f3053e6e1b47a39b26e757006fe2c44910d60a3809422ab76c25396b92e69b) + + +## 2. 
解决方案 + +### 2.1 扫描合同文本内容提取 + +使用PaddleOCR开源的模型可以快速完成扫描文档的文本内容提取,在清晰文档上识别准确率可达到95%+。下面来快速体验一下: + +#### 2.1.1 环境准备 + +[PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR)提供了适用于通用场景的高精轻量模型,提供数据预处理-模型推理-后处理全流程,支持pip安装: + +``` +python -m pip install paddleocr +``` + +#### 2.1.2 效果测试 + +使用一张合同图片作为测试样本,感受ppocrv3模型效果: + + + +使用中文检测+识别模型提取文本,实例化PaddleOCR类: + +``` +from paddleocr import PaddleOCR, draw_ocr + +# paddleocr目前支持中英文、英文、法语、德语、韩语、日语等80个语种,可以通过修改lang参数进行切换 +ocr = PaddleOCR(use_angle_cls=False, lang="ch") # need to run only once to download and load model into memory +``` + +一行命令启动预测,预测结果包括`检测框`和`文本识别内容`: + +``` +img_path = "./test_img/hetong2.jpg" +result = ocr.ocr(img_path, cls=False) +for line in result: + print(line) + +# 可视化结果 +from PIL import Image + +image = Image.open(img_path).convert('RGB') +boxes = [line[0] for line in result] +txts = [line[1][0] for line in result] +scores = [line[1][1] for line in result] +im_show = draw_ocr(image, boxes, txts, scores, font_path='./simfang.ttf') +im_show = Image.fromarray(im_show) +im_show.show() +``` + +#### 2.1.3 图片预处理 + +通过上图可视化结果可以看到,印章部分造成的文本遮盖,影响了文本识别结果,因此可以考虑通道提取,去除图片中的红色印章: + +``` +import cv2 +import numpy as np +import matplotlib.pyplot as plt + +#读入图像,三通道 +image=cv2.imread("./test_img/hetong2.jpg",cv2.IMREAD_COLOR) #timg.jpeg + +#获得三个通道 +Bch,Gch,Rch=cv2.split(image) + +#保存三通道图片 +cv2.imwrite('blue_channel.jpg',Bch) +cv2.imwrite('green_channel.jpg',Gch) +cv2.imwrite('red_channel.jpg',Rch) +``` +#### 2.1.4 合同文本信息提取 + +经过2.1.3的预处理后,合同照片的红色通道被分离,获得了一张相对更干净的图片,此时可以再次使用ppocr模型提取文本内容: + +``` +import numpy as np +import cv2 + + +img_path = './red_channel.jpg' +result = ocr.ocr(img_path, cls=False) + +# 可视化结果 +from PIL import Image + +image = Image.open(img_path).convert('RGB') +boxes = [line[0] for line in result] +txts = [line[1][0] for line in result] +scores = [line[1][1] for line in result] +im_show = draw_ocr(image, boxes, txts, scores, font_path='./simfang.ttf') +im_show = Image.fromarray(im_show) +vis = np.array(im_show) +im_show.show() +``` + +忽略检测框内容,提取完整的合同文本: + +``` +txts = [line[1][0] for line in result] +all_context = "\n".join(txts) +print(all_context) +``` + +通过以上环节就完成了扫描合同关键信息抽取的第一步:文本内容提取,接下来可以基于识别出的文本内容抽取关键信息 + +### 2.2 合同关键信息抽取 + +#### 2.2.1 环境准备 + +安装PaddleNLP + + +``` +pip install --upgrade pip +pip install --upgrade paddlenlp +``` + +#### 2.2.2 合同关键信息抽取 + +PaddleNLP 使用 Taskflow 统一管理多场景任务的预测功能,其中`information_extraction` 通过大量的有标签样本进行训练,在通用的场景中一般可以直接使用,只需更换关键字即可。例如在合同信息抽取中,我们重新定义抽取关键字: + +甲方、乙方、币种、金额、付款方式 + + +将使用OCR提取好的文本作为输入,使用三行命令可以对上文中提取到的合同文本进行关键信息抽取: + +``` +from paddlenlp import Taskflow +schema = ["甲方","乙方","总价"] +ie = Taskflow('information_extraction', schema=schema) +ie.set_schema(schema) +ie(all_context) +``` + +可以看到UIE模型可以准确的提取出关键信息,用于后续的信息比对或审核。 + +## 3.效果优化 + +### 3.1 文本识别后处理调优 + +实际图片采集过程中,可能出现部分图片弯曲等问题,导致使用默认参数识别文本时存在漏检,影响关键信息获取。 + +例如下图: + + + + +直接进行预测: + +``` +img_path = "./test_img/hetong3.jpg" +# 预测结果 +result = ocr.ocr(img_path, cls=False) +# 可视化结果 +from PIL import Image + +image = Image.open(img_path).convert('RGB') +boxes = [line[0] for line in result] +txts = [line[1][0] for line in result] +scores = [line[1][1] for line in result] +im_show = draw_ocr(image, boxes, txts, scores, font_path='./simfang.ttf') +im_show = Image.fromarray(im_show) +im_show.show() +``` + +可视化结果可以看到,弯曲图片存在漏检,一般来说可以通过调整后处理参数解决,无需重新训练模型。漏检问题往往是因为检测模型获得的分割图太小,生成框的得分过低被过滤掉了,通常有两种方式调整参数: +- 开启`use_dilatiion=True` 膨胀分割区域 +- 调小`det_db_box_thresh`阈值 + +``` +# 重新实例化 PaddleOCR +ocr = PaddleOCR(use_angle_cls=False, lang="ch", 
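+                # both kwargs are built-in PaddleOCR options: use_dilation dilates the
+                # detector's segmentation map, and lowering det_db_box_thresh keeps the
+                # low-score boxes that curved or warped text tends to produce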
det_db_box_thresh=0.3, use_dilation=True) + +# 预测并可视化 +img_path = "./test_img/hetong3.jpg" +# 预测结果 +result = ocr.ocr(img_path, cls=False) +# 可视化结果 +image = Image.open(img_path).convert('RGB') +boxes = [line[0] for line in result] +txts = [line[1][0] for line in result] +scores = [line[1][1] for line in result] +im_show = draw_ocr(image, boxes, txts, scores, font_path='./simfang.ttf') +im_show = Image.fromarray(im_show) +im_show.show() +``` + +可以看到漏检问题被很好的解决,提取完整的文本内容: + +``` +txts = [line[1][0] for line in result] +context = "\n".join(txts) +print(context) +``` + +### 3.2 关键信息提取调优 + +UIE通过大量有标签样本进行训练,得到了一个开箱即用的高精模型。 然而针对不同场景,可能会出现部分实体无法被抽取的情况。通常来说有以下几个方法进行效果调优: + + +- 修改 schema +- 添加正则方法 +- 标注小样本微调模型 + +**修改schema** + +Prompt和原文描述越像,抽取效果越好,例如 +``` +三:合同价格:总价为人民币大写:参拾玖万捌仟伍佰 +元,小写:398500.00元。总价中包括站房工程建设、安装 +及相关避雷、消防、接地、电力、材料费、检验费、安全、 +验收等所需费用及其他相关费用和税金。 +``` +schema = ["总金额"] 时无法准确抽取,与原文描述差异较大。 修改 schema = ["总价"] 再次尝试: + +``` +from paddlenlp import Taskflow +# schema = ["总金额"] +schema = ["总价"] +ie = Taskflow('information_extraction', schema=schema) +ie.set_schema(schema) +ie(all_context) +``` + + +**模型微调** + +UIE的建模方式主要是通过 `Prompt` 方式来建模, `Prompt` 在小样本上进行微调效果非常有效。详细的数据标注+模型微调步骤可以参考项目: + +[PaddleNLP信息抽取技术重磅升级!](https://aistudio.baidu.com/aistudio/projectdetail/3914778?channelType=0&channel=0) + +[工单信息抽取](https://aistudio.baidu.com/aistudio/projectdetail/3914778?contributionType=1) + +[快递单信息抽取](https://aistudio.baidu.com/aistudio/projectdetail/4038499?contributionType=1) + + +## 总结 + +扫描合同的关键信息提取可以使用 PaddleOCR + PaddleNLP 组合实现,两个工具均有以下优势: + +* 使用简单:whl包一键安装,3行命令调用 +* 效果领先:优秀的模型效果可覆盖几乎全部的应用场景 +* 调优成本低:OCR模型可通过后处理参数的调整适配略有偏差的扫描文本, UIE模型可以通过极少的标注样本微调,成本很低。 + +## 作业 + +尝试自己解析出 `test_img/homework.png` 扫描合同中的 [甲方、乙方] 关键词: + + + + + + + +更多场景下的垂类模型获取,请扫下图二维码填写问卷,加入PaddleOCR官方交流群获取模型下载链接、《动手学OCR》电子书等全套OCR学习资料🎁 + + diff --git a/configs/det/det_r18_vd_ct.yml b/configs/det/det_r18_vd_ct.yml new file mode 100644 index 0000000000000000000000000000000000000000..42922dfd22c0e49d20d50534c76fedae16b27a4a --- /dev/null +++ b/configs/det/det_r18_vd_ct.yml @@ -0,0 +1,107 @@ +Global: + use_gpu: true + epoch_num: 600 + log_smooth_window: 20 + print_batch_step: 10 + save_model_dir: ./output/det_ct/ + save_epoch_step: 10 + # evaluation is run every 2000 iterations + eval_batch_step: [0,1000] + cal_metric_during_train: False + pretrained_model: ./pretrain_models/ResNet18_vd_pretrained.pdparams + checkpoints: + save_inference_dir: + use_visualdl: False + infer_img: doc/imgs_en/img623.jpg + save_res_path: ./output/det_ct/predicts_ct.txt + +Architecture: + model_type: det + algorithm: CT + Transform: + Backbone: + name: ResNet_vd + layers: 18 + Neck: + name: CTFPN + Head: + name: CT_Head + in_channels: 512 + hidden_dim: 128 + num_classes: 3 + +Loss: + name: CTLoss + +Optimizer: + name: Adam + lr: #PolynomialDecay + name: Linear + learning_rate: 0.001 + end_lr: 0. 
+ epochs: 600 + step_each_epoch: 1254 + power: 0.9 + +PostProcess: + name: CTPostProcess + box_type: poly + +Metric: + name: CTMetric + main_indicator: f_score + +Train: + dataset: + name: SimpleDataSet + data_dir: ./train_data/total_text/train + label_file_list: + - ./train_data/total_text/train/train.txt + ratio_list: [1.0] + transforms: + - DecodeImage: + img_mode: RGB + channel_first: False + - CTLabelEncode: # Class handling label + - RandomScale: + - MakeShrink: + - GroupRandomHorizontalFlip: + - GroupRandomRotate: + - GroupRandomCropPadding: + - MakeCentripetalShift: + - ColorJitter: + brightness: 0.125 + saturation: 0.5 + - ToCHWImage: + - NormalizeImage: + - KeepKeys: + keep_keys: ['image', 'gt_kernel', 'training_mask', 'gt_instance', 'gt_kernel_instance', 'training_mask_distance', 'gt_distance'] # the order of the dataloader list + loader: + shuffle: True + drop_last: True + batch_size_per_card: 4 + num_workers: 8 + +Eval: + dataset: + name: SimpleDataSet + data_dir: ./train_data/total_text/test + label_file_list: + - ./train_data/total_text/test/test.txt + ratio_list: [1.0] + transforms: + - DecodeImage: + img_mode: RGB + channel_first: False + - CTLabelEncode: # Class handling label + - ScaleAlignedShort: + - NormalizeImage: + order: 'hwc' + - ToCHWImage: + - KeepKeys: + keep_keys: ['image', 'shape', 'polys', 'texts'] # the order of the dataloader list + loader: + shuffle: False + drop_last: False + batch_size_per_card: 1 + num_workers: 2 diff --git a/configs/e2e/e2e_r50_vd_pg.yml b/configs/e2e/e2e_r50_vd_pg.yml index c4c5226e796a42db723ce78ef65473e357c25dc6..4642f544868f720d413f7f5242740705bc9fd0a5 100644 --- a/configs/e2e/e2e_r50_vd_pg.yml +++ b/configs/e2e/e2e_r50_vd_pg.yml @@ -13,6 +13,7 @@ Global: save_inference_dir: use_visualdl: False infer_img: + infer_visual_type: EN # two mode: EN is for english datasets, CN is for chinese datasets valid_set: totaltext # two mode: totaltext valid curved words, partvgg valid non-curved words save_res_path: ./output/pgnet_r50_vd_totaltext/predicts_pgnet.txt character_dict_path: ppocr/utils/ic15_dict.txt @@ -32,6 +33,7 @@ Architecture: name: PGFPN Head: name: PGHead + character_dict_path: ppocr/utils/ic15_dict.txt # the same as Global:character_dict_path Loss: name: PGLoss @@ -45,16 +47,18 @@ Optimizer: beta1: 0.9 beta2: 0.999 lr: + name: Cosine learning_rate: 0.001 + warmup_epoch: 50 regularizer: name: 'L2' - factor: 0 - + factor: 0.0001 PostProcess: name: PGPostProcess score_thresh: 0.5 mode: fast # fast or slow two ways + point_gather_mode: align # same as PGProcessTrain: point_gather_mode Metric: name: E2EMetric @@ -76,9 +80,12 @@ Train: - E2ELabelEncodeTrain: - PGProcessTrain: batch_size: 14 # same as loader: batch_size_per_card + use_resize: True + use_random_crop: False min_crop_size: 24 min_text_size: 4 max_text_size: 512 + point_gather_mode: align # two mode: align and none, align mode is better than none mode - KeepKeys: keep_keys: [ 'images', 'tcl_maps', 'tcl_label_maps', 'border_maps','direction_maps', 'training_masks', 'label_list', 'pos_list', 'pos_mask' ] # dataloader will return list in this order loader: diff --git a/configs/kie/layoutlm_series/re_layoutxlm_xfund_zh.yml b/configs/kie/layoutlm_series/re_layoutxlm_xfund_zh.yml index 2401cf317987c5614a476065191e750587bc09b5..99dc771d150b15847486c096529a2828b9c0c05a 100644 --- a/configs/kie/layoutlm_series/re_layoutxlm_xfund_zh.yml +++ b/configs/kie/layoutlm_series/re_layoutxlm_xfund_zh.yml @@ -68,6 +68,7 @@ Train: - VQAReTokenRelation: - VQAReTokenChunk: max_seq_len: 
*max_seq_len + - TensorizeEntitiesRelations: - Resize: size: [224,224] - NormalizeImage: @@ -83,7 +84,6 @@ Train: drop_last: False batch_size_per_card: 2 num_workers: 8 - collate_fn: ListCollator Eval: dataset: @@ -105,6 +105,7 @@ Eval: - VQAReTokenRelation: - VQAReTokenChunk: max_seq_len: *max_seq_len + - TensorizeEntitiesRelations: - Resize: size: [224,224] - NormalizeImage: @@ -120,4 +121,3 @@ Eval: drop_last: False batch_size_per_card: 8 num_workers: 8 - collate_fn: ListCollator diff --git a/configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh.yml b/configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh.yml index ea9f50ef56ec8b169333263c1d5e96586f9472b3..e65af0a064418f1f21725f6b9e249a8be8391f41 100644 --- a/configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh.yml +++ b/configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh.yml @@ -73,6 +73,7 @@ Train: - VQAReTokenRelation: - VQAReTokenChunk: max_seq_len: *max_seq_len + - TensorizeEntitiesRelations: - Resize: size: [224,224] - NormalizeImage: @@ -82,13 +83,12 @@ Train: order: 'hwc' - ToCHWImage: - KeepKeys: - keep_keys: [ 'input_ids', 'bbox','attention_mask', 'token_type_ids', 'image', 'entities', 'relations'] # dataloader will return list in this order + keep_keys: [ 'input_ids', 'bbox','attention_mask', 'token_type_ids', 'entities', 'relations'] # dataloader will return list in this order loader: shuffle: True drop_last: False batch_size_per_card: 2 num_workers: 4 - collate_fn: ListCollator Eval: dataset: @@ -112,6 +112,7 @@ Eval: - VQAReTokenRelation: - VQAReTokenChunk: max_seq_len: *max_seq_len + - TensorizeEntitiesRelations: - Resize: size: [224,224] - NormalizeImage: @@ -121,11 +122,9 @@ Eval: order: 'hwc' - ToCHWImage: - KeepKeys: - keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'entities', 'relations'] # dataloader will return list in this order + keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'entities', 'relations'] # dataloader will return list in this order loader: shuffle: False drop_last: False batch_size_per_card: 8 num_workers: 8 - collate_fn: ListCollator - diff --git a/configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh_udml.yml b/configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh_udml.yml index b96528d2738e7cfb2575feca4146af1eed0c5d2f..eda9fcddb9abdd2611dd851566c0f327278b51fc 100644 --- a/configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh_udml.yml +++ b/configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh_udml.yml @@ -57,14 +57,16 @@ Loss: mode: "l2" model_name_pairs: - ["Student", "Teacher"] - key: hidden_states_5 + key: hidden_states + index: 5 name: "loss_5" - DistillationVQADistanceLoss: weight: 0.5 mode: "l2" model_name_pairs: - ["Student", "Teacher"] - key: hidden_states_8 + key: hidden_states + index: 8 name: "loss_8" @@ -116,6 +118,7 @@ Train: - VQAReTokenRelation: - VQAReTokenChunk: max_seq_len: *max_seq_len + - TensorizeEntitiesRelations: - Resize: size: [224,224] - NormalizeImage: @@ -125,13 +128,12 @@ Train: order: 'hwc' - ToCHWImage: - KeepKeys: - keep_keys: [ 'input_ids', 'bbox','attention_mask', 'token_type_ids', 'image', 'entities', 'relations'] # dataloader will return list in this order + keep_keys: [ 'input_ids', 'bbox','attention_mask', 'token_type_ids', 'entities', 'relations'] # dataloader will return list in this order loader: shuffle: True drop_last: False batch_size_per_card: 2 num_workers: 4 - collate_fn: ListCollator Eval: dataset: @@ -155,6 +157,7 @@ Eval: - VQAReTokenRelation: - VQAReTokenChunk: max_seq_len: *max_seq_len + - TensorizeEntitiesRelations: - 
Resize: size: [224,224] - NormalizeImage: @@ -164,12 +167,11 @@ Eval: order: 'hwc' - ToCHWImage: - KeepKeys: - keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'entities', 'relations'] # dataloader will return list in this order + keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'entities', 'relations'] # dataloader will return list in this order loader: shuffle: False drop_last: False batch_size_per_card: 8 num_workers: 8 - collate_fn: ListCollator diff --git a/configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh_udml.yml b/configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh_udml.yml index 238bbd2b2c7083b5534062afd3e6c11a87494a56..79abe540936b9dd54ac04a935059e784d3fea153 100644 --- a/configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh_udml.yml +++ b/configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh_udml.yml @@ -70,14 +70,16 @@ Loss: mode: "l2" model_name_pairs: - ["Student", "Teacher"] - key: hidden_states_5 + key: hidden_states + index: 5 name: "loss_5" - DistillationVQADistanceLoss: weight: 0.5 mode: "l2" model_name_pairs: - ["Student", "Teacher"] - key: hidden_states_8 + key: hidden_states + index: 8 name: "loss_8" diff --git a/configs/rec/PP-OCRv3/ch_PP-OCRv3_rec.yml b/configs/rec/PP-OCRv3/ch_PP-OCRv3_rec.yml index 6453934b7324b2b351aeb6fdf8e4e4de24b022bf..7e98280b32558b8d3d203084e6e327bc7cd782bf 100644 --- a/configs/rec/PP-OCRv3/ch_PP-OCRv3_rec.yml +++ b/configs/rec/PP-OCRv3/ch_PP-OCRv3_rec.yml @@ -88,6 +88,7 @@ Train: prob: 0.5 ext_data_num: 2 image_shape: [48, 320, 3] + max_text_length: *max_text_length - RecAug: - MultiLabelEncode: - RecResizeImg: diff --git a/configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml b/configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml index e7cbae59a14af73639e1a74a14021b9b2ef60057..427255738696d8e6a073829350c40b00ef30115f 100644 --- a/configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml +++ b/configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml @@ -162,6 +162,7 @@ Train: prob: 0.5 ext_data_num: 2 image_shape: [48, 320, 3] + max_text_length: *max_text_length - RecAug: - MultiLabelEncode: - RecResizeImg: diff --git a/configs/rec/PP-OCRv3/en_PP-OCRv3_rec.yml b/configs/rec/PP-OCRv3/en_PP-OCRv3_rec.yml index ff536edec4d6e7a85a6e6c189d56a23ffabc5583..c728e0ac823b0bf835322dcbd0c385c3ac7b2489 100644 --- a/configs/rec/PP-OCRv3/en_PP-OCRv3_rec.yml +++ b/configs/rec/PP-OCRv3/en_PP-OCRv3_rec.yml @@ -88,6 +88,7 @@ Train: prob: 0.5 ext_data_num: 2 image_shape: [48, 320, 3] + max_text_length: *max_text_length - RecAug: - MultiLabelEncode: - RecResizeImg: diff --git a/configs/rec/rec_r31_robustscanner.yml b/configs/rec/rec_r31_robustscanner.yml index 40d39aee3c42c18085ace035944dba057b923245..54b69d456ef67c88289504ed5cf8719588ea0803 100644 --- a/configs/rec/rec_r31_robustscanner.yml +++ b/configs/rec/rec_r31_robustscanner.yml @@ -12,7 +12,7 @@ Global: checkpoints: save_inference_dir: use_visualdl: False - infer_img: ./inference/rec_inference + infer_img: doc/imgs_words_en/word_10.png # for data or label process character_dict_path: ppocr/utils/dict90.txt max_text_length: &max_text_length 40 diff --git a/configs/rec/rec_r32_gaspin_bilstm_att.yml b/configs/rec/rec_r32_gaspin_bilstm_att.yml index aea71388f703376120af4d0caf2fa8ccd4d92cce..91d3e104188c04f4372a6e574b7fc07608234a2e 100644 --- a/configs/rec/rec_r32_gaspin_bilstm_att.yml +++ b/configs/rec/rec_r32_gaspin_bilstm_att.yml @@ -12,7 +12,7 @@ Global: checkpoints: save_inference_dir: use_visualdl: False - infer_img: doc/imgs_words/ch/word_1.jpg + infer_img: 
doc/imgs_words_en/word_10.png # for data or label process character_dict_path: ./ppocr/utils/dict/spin_dict.txt max_text_length: 25 diff --git a/configs/table/SLANet.yml b/configs/table/SLANet.yml index 384c95852e815f9780328f63cbbd52fa0ef3deb4..a896614556e36f77bd784218b6c2f29914219dbe 100644 --- a/configs/table/SLANet.yml +++ b/configs/table/SLANet.yml @@ -12,7 +12,7 @@ Global: checkpoints: save_inference_dir: ./output/SLANet/infer use_visualdl: False - infer_img: doc/table/table.jpg + infer_img: ppstructure/docs/table/table.jpg # for data or label process character_dict_path: ppocr/utils/dict/table_structure_dict.txt character_type: en diff --git a/configs/table/SLANet_ch.yml b/configs/table/SLANet_ch.yml index 997ff0a77b5ea824957abc1d32a7ba7f70abc12c..3b1e5c6bd9dd4cd2a084d557a1285983a56bdf2a 100644 --- a/configs/table/SLANet_ch.yml +++ b/configs/table/SLANet_ch.yml @@ -12,7 +12,7 @@ Global: checkpoints: save_inference_dir: ./output/SLANet_ch/infer use_visualdl: False - infer_img: doc/table/table.jpg + infer_img: ppstructure/docs/table/table.jpg # for data or label process character_dict_path: ppocr/utils/dict/table_structure_dict_ch.txt character_type: en @@ -107,7 +107,7 @@ Train: Eval: dataset: name: PubTabDataSet - data_dir: train_data/table/val/ + data_dir: train_data/table/val/ label_file_list: [train_data/table/val.txt] transforms: - DecodeImage: diff --git a/configs/table/table_mv3.yml b/configs/table/table_mv3.yml index 9355a236e15b60db18e8715c2702701fd5d36c71..9d286f4153eaab44bf0d259bbad4a0b3b8ada568 100755 --- a/configs/table/table_mv3.yml +++ b/configs/table/table_mv3.yml @@ -43,7 +43,6 @@ Architecture: Head: name: TableAttentionHead hidden_size: 256 - loc_type: 2 max_text_length: *max_text_length loc_reg_num: &loc_reg_num 4 diff --git a/deploy/cpp_infer/include/args.h b/deploy/cpp_infer/include/args.h index e0dd8bbcd1044fd695c90805bc770de5b47e51cf..e6e76ef927c16f6afe381f64ea8dde4ac99185cf 100644 --- a/deploy/cpp_infer/include/args.h +++ b/deploy/cpp_infer/include/args.h @@ -49,13 +49,20 @@ DECLARE_int32(rec_batch_num); DECLARE_string(rec_char_dict_path); DECLARE_int32(rec_img_h); DECLARE_int32(rec_img_w); +// layout model related +DECLARE_string(layout_model_dir); +DECLARE_string(layout_dict_path); +DECLARE_double(layout_score_threshold); +DECLARE_double(layout_nms_threshold); // structure model related DECLARE_string(table_model_dir); DECLARE_int32(table_max_len); DECLARE_int32(table_batch_num); DECLARE_string(table_char_dict_path); +DECLARE_bool(merge_no_span_structure); // forward related DECLARE_bool(det); DECLARE_bool(rec); DECLARE_bool(cls); -DECLARE_bool(table); \ No newline at end of file +DECLARE_bool(table); +DECLARE_bool(layout); \ No newline at end of file diff --git a/deploy/cpp_infer/include/ocr_cls.h b/deploy/cpp_infer/include/ocr_cls.h index f5429a7c5bc58c2640f042811ad0eed23f29feba..f5a0356573b3219865e0c9fe08d57358d3a2c88c 100644 --- a/deploy/cpp_infer/include/ocr_cls.h +++ b/deploy/cpp_infer/include/ocr_cls.h @@ -14,26 +14,12 @@ #pragma once -#include "opencv2/core.hpp" -#include "opencv2/imgcodecs.hpp" -#include "opencv2/imgproc.hpp" #include "paddle_api.h" #include "paddle_inference_api.h" -#include -#include -#include -#include -#include - -#include -#include -#include #include #include -using namespace paddle_infer; - namespace PaddleOCR { class Classifier { @@ -66,7 +52,7 @@ public: std::vector &cls_scores, std::vector ×); private: - std::shared_ptr predictor_; + std::shared_ptr predictor_; bool use_gpu_ = false; int gpu_id_ = 0; diff --git 
a/deploy/cpp_infer/include/ocr_det.h b/deploy/cpp_infer/include/ocr_det.h index d1421b103b28b44e15a7df53a63fd893ca60e529..9f6f2520540f96dfa53f5c4c907317bb8ff04013 100644 --- a/deploy/cpp_infer/include/ocr_det.h +++ b/deploy/cpp_infer/include/ocr_det.h @@ -14,26 +14,12 @@ #pragma once -#include "opencv2/core.hpp" -#include "opencv2/imgcodecs.hpp" -#include "opencv2/imgproc.hpp" #include "paddle_api.h" #include "paddle_inference_api.h" -#include -#include -#include -#include -#include - -#include -#include -#include #include #include -using namespace paddle_infer; - namespace PaddleOCR { class DBDetector { @@ -41,7 +27,7 @@ public: explicit DBDetector(const std::string &model_dir, const bool &use_gpu, const int &gpu_id, const int &gpu_mem, const int &cpu_math_library_num_threads, - const bool &use_mkldnn, const string &limit_type, + const bool &use_mkldnn, const std::string &limit_type, const int &limit_side_len, const double &det_db_thresh, const double &det_db_box_thresh, const double &det_db_unclip_ratio, @@ -77,7 +63,7 @@ public: std::vector ×); private: - std::shared_ptr predictor_; + std::shared_ptr predictor_; bool use_gpu_ = false; int gpu_id_ = 0; @@ -85,7 +71,7 @@ private: int cpu_math_library_num_threads_ = 4; bool use_mkldnn_ = false; - string limit_type_ = "max"; + std::string limit_type_ = "max"; int limit_side_len_ = 960; double det_db_thresh_ = 0.3; diff --git a/deploy/cpp_infer/include/ocr_rec.h b/deploy/cpp_infer/include/ocr_rec.h index 30f8efa9996a62adc74717dd46f2aef7fc96b091..257c261033bf8f8c0ce605ba90cedfbb49d844dc 100644 --- a/deploy/cpp_infer/include/ocr_rec.h +++ b/deploy/cpp_infer/include/ocr_rec.h @@ -14,27 +14,12 @@ #pragma once -#include "opencv2/core.hpp" -#include "opencv2/imgcodecs.hpp" -#include "opencv2/imgproc.hpp" #include "paddle_api.h" #include "paddle_inference_api.h" -#include -#include -#include -#include -#include - -#include -#include -#include #include -#include #include -using namespace paddle_infer; - namespace PaddleOCR { class CRNNRecognizer { @@ -42,7 +27,7 @@ public: explicit CRNNRecognizer(const std::string &model_dir, const bool &use_gpu, const int &gpu_id, const int &gpu_mem, const int &cpu_math_library_num_threads, - const bool &use_mkldnn, const string &label_path, + const bool &use_mkldnn, const std::string &label_path, const bool &use_tensorrt, const std::string &precision, const int &rec_batch_num, const int &rec_img_h, @@ -75,7 +60,7 @@ public: std::vector &rec_text_scores, std::vector ×); private: - std::shared_ptr predictor_; + std::shared_ptr predictor_; bool use_gpu_ = false; int gpu_id_ = 0; diff --git a/deploy/cpp_infer/include/paddleocr.h b/deploy/cpp_infer/include/paddleocr.h index a2c60b14acceaa90a8d8e4a70ccc50f02f254eb6..16750a15f70d374f8aa837042ba6a13bc10a5d35 100644 --- a/deploy/cpp_infer/include/paddleocr.h +++ b/deploy/cpp_infer/include/paddleocr.h @@ -14,28 +14,9 @@ #pragma once -#include "opencv2/core.hpp" -#include "opencv2/imgcodecs.hpp" -#include "opencv2/imgproc.hpp" -#include "paddle_api.h" -#include "paddle_inference_api.h" -#include -#include -#include -#include -#include - -#include -#include -#include - #include #include #include -#include -#include - -using namespace paddle_infer; namespace PaddleOCR { @@ -43,21 +24,27 @@ class PPOCR { public: explicit PPOCR(); ~PPOCR(); - std::vector> - ocr(std::vector cv_all_img_names, bool det = true, - bool rec = true, bool cls = true); + + std::vector> ocr(std::vector img_list, + bool det = true, + bool rec = true, + bool cls = true); + std::vector ocr(cv::Mat img, bool 
det = true, + bool rec = true, bool cls = true); + + void reset_timer(); + void benchmark_log(int img_num); protected: - void det(cv::Mat img, std::vector &ocr_results, - std::vector ×); + std::vector time_info_det = {0, 0, 0}; + std::vector time_info_rec = {0, 0, 0}; + std::vector time_info_cls = {0, 0, 0}; + + void det(cv::Mat img, std::vector &ocr_results); void rec(std::vector img_list, - std::vector &ocr_results, - std::vector ×); + std::vector &ocr_results); void cls(std::vector img_list, - std::vector &ocr_results, - std::vector ×); - void log(std::vector &det_times, std::vector &rec_times, - std::vector &cls_times, int img_num); + std::vector &ocr_results); private: DBDetector *detector_ = nullptr; diff --git a/deploy/cpp_infer/include/paddlestructure.h b/deploy/cpp_infer/include/paddlestructure.h index b30ac045b2a6552b69442b2e8b29673efc820e31..8478a85cdec23984f86a323f55a4591d52bcf08c 100644 --- a/deploy/cpp_infer/include/paddlestructure.h +++ b/deploy/cpp_infer/include/paddlestructure.h @@ -14,27 +14,9 @@ #pragma once -#include "opencv2/core.hpp" -#include "opencv2/imgcodecs.hpp" -#include "opencv2/imgproc.hpp" -#include "paddle_api.h" -#include "paddle_inference_api.h" -#include -#include -#include -#include -#include - -#include -#include -#include - #include -#include +#include #include -#include - -using namespace paddle_infer; namespace PaddleOCR { @@ -42,27 +24,32 @@ class PaddleStructure : public PPOCR { public: explicit PaddleStructure(); ~PaddleStructure(); - std::vector> - structure(std::vector cv_all_img_names, bool layout = false, - bool table = true); + + std::vector structure(cv::Mat img, + bool layout = false, + bool table = true, + bool ocr = false); + + void reset_timer(); + void benchmark_log(int img_num); private: - StructureTableRecognizer *recognizer_ = nullptr; + std::vector time_info_table = {0, 0, 0}; + std::vector time_info_layout = {0, 0, 0}; + + StructureTableRecognizer *table_model_ = nullptr; + StructureLayoutRecognizer *layout_model_ = nullptr; + + void layout(cv::Mat img, + std::vector &structure_result); + + void table(cv::Mat img, StructurePredictResult &structure_result); - void table(cv::Mat img, StructurePredictResult &structure_result, - std::vector &time_info_table, - std::vector &time_info_det, - std::vector &time_info_rec, - std::vector &time_info_cls); - std::string - rebuild_table(std::vector rec_html_tags, - std::vector>> rec_boxes, - std::vector &ocr_result); + std::string rebuild_table(std::vector rec_html_tags, + std::vector> rec_boxes, + std::vector &ocr_result); - float iou(std::vector> &box1, - std::vector> &box2); - float dis(std::vector> &box1, - std::vector> &box2); + float dis(std::vector &box1, std::vector &box2); static bool comparison_dis(const std::vector &dis1, const std::vector &dis2) { diff --git a/deploy/cpp_infer/include/postprocess_op.h b/deploy/cpp_infer/include/postprocess_op.h index 77b3f8b660bda29815245b31ab8cac479b24498f..e267eeee1dd8055b05bb10c89149ad31779aabc7 100644 --- a/deploy/cpp_infer/include/postprocess_op.h +++ b/deploy/cpp_infer/include/postprocess_op.h @@ -14,24 +14,9 @@ #pragma once -#include "opencv2/core.hpp" -#include "opencv2/imgcodecs.hpp" -#include "opencv2/imgproc.hpp" -#include -#include -#include -#include -#include - -#include -#include -#include - #include "include/clipper.h" #include "include/utility.h" -using namespace std; - namespace PaddleOCR { class DBPostProcessor { @@ -92,14 +77,13 @@ private: class TablePostProcessor { public: - void init(std::string label_path); - void - 
Run(std::vector &loc_preds, std::vector &structure_probs, - std::vector &rec_scores, std::vector &loc_preds_shape, - std::vector &structure_probs_shape, - std::vector> &rec_html_tag_batch, - std::vector>>> &rec_boxes_batch, - std::vector &width_list, std::vector &height_list); + void init(std::string label_path, bool merge_no_span_structure = true); + void Run(std::vector &loc_preds, std::vector &structure_probs, + std::vector &rec_scores, std::vector &loc_preds_shape, + std::vector &structure_probs_shape, + std::vector> &rec_html_tag_batch, + std::vector>> &rec_boxes_batch, + std::vector &width_list, std::vector &height_list); private: std::vector label_list_; @@ -107,4 +91,27 @@ private: std::string beg = "sos"; }; +class PicodetPostProcessor { +public: + void init(std::string label_path, const double score_threshold = 0.4, + const double nms_threshold = 0.5, + const std::vector &fpn_stride = {8, 16, 32, 64}); + void Run(std::vector &results, + std::vector> outs, std::vector ori_shape, + std::vector resize_shape, int eg_max); + std::vector fpn_stride_ = {8, 16, 32, 64}; + +private: + StructurePredictResult disPred2Bbox(std::vector bbox_pred, int label, + float score, int x, int y, int stride, + std::vector im_shape, int reg_max); + void nms(std::vector &input_boxes, + float nms_threshold); + + std::vector label_list_; + double score_threshold_ = 0.4; + double nms_threshold_ = 0.5; + int num_class_ = 5; +}; + } // namespace PaddleOCR diff --git a/deploy/cpp_infer/include/preprocess_op.h b/deploy/cpp_infer/include/preprocess_op.h index 078f19d5b808c81e88d7aa464d6bfaca7fe1b14e..0b2e18330cbb5d8455cc17a508ab1f12de0f389a 100644 --- a/deploy/cpp_infer/include/preprocess_op.h +++ b/deploy/cpp_infer/include/preprocess_op.h @@ -14,21 +14,12 @@ #pragma once -#include "opencv2/core.hpp" -#include "opencv2/imgcodecs.hpp" -#include "opencv2/imgproc.hpp" -#include -#include #include -#include #include -#include -#include -#include - -using namespace std; -using namespace paddle; +#include "opencv2/core.hpp" +#include "opencv2/imgcodecs.hpp" +#include "opencv2/imgproc.hpp" namespace PaddleOCR { @@ -51,9 +42,9 @@ public: class ResizeImgType0 { public: - virtual void Run(const cv::Mat &img, cv::Mat &resize_img, string limit_type, - int limit_side_len, float &ratio_h, float &ratio_w, - bool use_tensorrt); + virtual void Run(const cv::Mat &img, cv::Mat &resize_img, + std::string limit_type, int limit_side_len, float &ratio_h, + float &ratio_w, bool use_tensorrt); }; class CrnnResizeImg { @@ -82,4 +73,10 @@ public: const int max_len = 488); }; +class Resize { +public: + virtual void Run(const cv::Mat &img, cv::Mat &resize_img, const int h, + const int w); +}; + } // namespace PaddleOCR \ No newline at end of file diff --git a/deploy/cpp_infer/include/structure_layout.h b/deploy/cpp_infer/include/structure_layout.h new file mode 100644 index 0000000000000000000000000000000000000000..3dd605720fa1dc009e8f1b28768d221678df713e --- /dev/null +++ b/deploy/cpp_infer/include/structure_layout.h @@ -0,0 +1,78 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle_api.h" +#include "paddle_inference_api.h" + +#include +#include + +namespace PaddleOCR { + +class StructureLayoutRecognizer { +public: + explicit StructureLayoutRecognizer( + const std::string &model_dir, const bool &use_gpu, const int &gpu_id, + const int &gpu_mem, const int &cpu_math_library_num_threads, + const bool &use_mkldnn, const std::string &label_path, + const bool &use_tensorrt, const std::string &precision, + const double &layout_score_threshold, + const double &layout_nms_threshold) { + this->use_gpu_ = use_gpu; + this->gpu_id_ = gpu_id; + this->gpu_mem_ = gpu_mem; + this->cpu_math_library_num_threads_ = cpu_math_library_num_threads; + this->use_mkldnn_ = use_mkldnn; + this->use_tensorrt_ = use_tensorrt; + this->precision_ = precision; + + this->post_processor_.init(label_path, layout_score_threshold, + layout_nms_threshold); + LoadModel(model_dir); + } + + // Load Paddle inference model + void LoadModel(const std::string &model_dir); + + void Run(cv::Mat img, std::vector &result, + std::vector ×); + +private: + std::shared_ptr predictor_; + + bool use_gpu_ = false; + int gpu_id_ = 0; + int gpu_mem_ = 4000; + int cpu_math_library_num_threads_ = 4; + bool use_mkldnn_ = false; + + std::vector mean_ = {0.485f, 0.456f, 0.406f}; + std::vector scale_ = {1 / 0.229f, 1 / 0.224f, 1 / 0.225f}; + bool is_scale_ = true; + + bool use_tensorrt_ = false; + std::string precision_ = "fp32"; + + // pre-process + Resize resize_op_; + Normalize normalize_op_; + Permute permute_op_; + + // post-process + PicodetPostProcessor post_processor_; +}; + +} // namespace PaddleOCR \ No newline at end of file diff --git a/deploy/cpp_infer/include/structure_table.h b/deploy/cpp_infer/include/structure_table.h index 7449c6cd0e158425bccb75740191dd0b6d6ecc9b..616e95d212c948ab165bc73da7758a263583eb98 100644 --- a/deploy/cpp_infer/include/structure_table.h +++ b/deploy/cpp_infer/include/structure_table.h @@ -14,26 +14,11 @@ #pragma once -#include "opencv2/core.hpp" -#include "opencv2/imgcodecs.hpp" -#include "opencv2/imgproc.hpp" #include "paddle_api.h" #include "paddle_inference_api.h" -#include -#include -#include -#include -#include - -#include -#include -#include #include #include -#include - -using namespace paddle_infer; namespace PaddleOCR { @@ -42,9 +27,10 @@ public: explicit StructureTableRecognizer( const std::string &model_dir, const bool &use_gpu, const int &gpu_id, const int &gpu_mem, const int &cpu_math_library_num_threads, - const bool &use_mkldnn, const string &label_path, + const bool &use_mkldnn, const std::string &label_path, const bool &use_tensorrt, const std::string &precision, - const int &table_batch_num, const int &table_max_len) { + const int &table_batch_num, const int &table_max_len, + const bool &merge_no_span_structure) { this->use_gpu_ = use_gpu; this->gpu_id_ = gpu_id; this->gpu_mem_ = gpu_mem; @@ -55,7 +41,7 @@ public: this->table_batch_num_ = table_batch_num; this->table_max_len_ = table_max_len; - this->post_processor_.init(label_path); + this->post_processor_.init(label_path, merge_no_span_structure); LoadModel(model_dir); } 
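+  // merge_no_span_structure (added by this patch) asks the post-processor to merge the
+  // "<td>" and "</td>" tokens of the table structure dictionary into a single
+  // "<td></td>" token when label_path is loaded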
@@ -65,11 +51,11 @@ public:
   void Run(std::vector<cv::Mat> img_list,
            std::vector<std::vector<std::string>> &rec_html_tags,
            std::vector<float> &rec_scores,
-           std::vector<std::vector<std::vector<std::vector<int>>>> &rec_boxes,
+           std::vector<std::vector<std::vector<int>>> &rec_boxes,
            std::vector<double> &times);
 
 private:
-  std::shared_ptr<Predictor> predictor_;
+  std::shared_ptr<paddle_infer::Predictor> predictor_;
 
   bool use_gpu_ = false;
   int gpu_id_ = 0;
diff --git a/deploy/cpp_infer/include/utility.h b/deploy/cpp_infer/include/utility.h
index 520804f64529303b5ecec27dc5f0895f1fff5c72..7dfe03dd625e7b31bc64d875c893ea132b46423c 100644
--- a/deploy/cpp_infer/include/utility.h
+++ b/deploy/cpp_infer/include/utility.h
@@ -41,11 +41,13 @@ struct OCRPredictResult {
 };
 
 struct StructurePredictResult {
-  std::vector<int> box;
+  std::vector<float> box;
+  std::vector<std::vector<int>> cell_box;
   std::string type;
   std::vector<OCRPredictResult> text_res;
   std::string html;
   float html_score = -1;
+  float confidence;
 };
 
 class Utility {
@@ -56,6 +58,10 @@ public:
                               const std::vector<OCRPredictResult> &ocr_result,
                               const std::string &save_path);
 
+  static void VisualizeBboxes(const cv::Mat &srcimg,
+                              const StructurePredictResult &structure_result,
+                              const std::string &save_path);
+
   template <class ForwardIterator>
   inline static size_t argmax(ForwardIterator first, ForwardIterator last) {
     return std::distance(first, std::max_element(first, last));
@@ -77,10 +83,20 @@ public:
 
   static void print_result(const std::vector<OCRPredictResult> &ocr_result);
 
-  static cv::Mat crop_image(cv::Mat &img, std::vector<int> &area);
+  static cv::Mat crop_image(cv::Mat &img, const std::vector<int> &area);
+  static cv::Mat crop_image(cv::Mat &img, const std::vector<float> &area);
 
   static void sorted_boxes(std::vector<OCRPredictResult> &ocr_result);
 
+  static std::vector<int> xyxyxyxy2xyxy(std::vector<std::vector<int>> &box);
+  static std::vector<int> xyxyxyxy2xyxy(std::vector<int> &box);
+
+  static float fast_exp(float x);
+  static std::vector<float>
+  activation_function_softmax(std::vector<float> &src);
+  static float iou(std::vector<int> &box1, std::vector<int> &box2);
+  static float iou(std::vector<float> &box1, std::vector<float> &box2);
+
 private:
   static bool comparison_box(const OCRPredictResult &result1,
                              const OCRPredictResult &result2) {
diff --git a/deploy/cpp_infer/readme.md b/deploy/cpp_infer/readme.md
index 2afdf79521223c4f473ded8d4f930546fb762c46..d176ff986295088a15f4e20b16a7986c3640387b 100644
--- a/deploy/cpp_infer/readme.md
+++ b/deploy/cpp_infer/readme.md
@@ -174,6 +174,9 @@ inference/
 |-- table
 |   |--inference.pdiparams
 |   |--inference.pdmodel
+|-- layout
+|   |--inference.pdiparams
+|   |--inference.pdmodel
 ```
 
@@ -278,8 +281,30 @@ Specifically,
     --cls=true \
 ```
 
+##### 7. layout+table
+```shell
+./build/ppocr --det_model_dir=inference/det_db \
+    --rec_model_dir=inference/rec_rcnn \
+    --table_model_dir=inference/table \
+    --image_dir=../../ppstructure/docs/table/table.jpg \
+    --layout_model_dir=inference/layout \
+    --type=structure \
+    --table=true \
+    --layout=true
+```
+
+##### 8. layout
+```shell
+./build/ppocr --layout_model_dir=inference/layout \
+    --image_dir=../../ppstructure/docs/table/1.png \
+    --type=structure \
+    --table=false \
+    --layout=true \
+    --det=false \
+    --rec=false
+```
 
-##### 7. table
+##### 9.
table ```shell ./build/ppocr --det_model_dir=inference/det_db \ --rec_model_dir=inference/rec_rcnn \ @@ -343,6 +368,16 @@ More parameters are as follows, |rec_img_h|int|48|image height of recognition| |rec_img_w|int|320|image width of recognition| +- Layout related parameters + +|parameter|data type|default|meaning| +| :---: | :---: | :---: | :---: | +|layout_model_dir|string|-| Address of layout inference model| +|layout_dict_path|string|../../ppocr/utils/dict/layout_dict/layout_publaynet_dict.txt|dictionary file| +|layout_score_threshold|float|0.5|Threshold of score.| +|layout_nms_threshold|float|0.5|Threshold of nms.| + + - Table recognition related parameters |parameter|data type|default|meaning| @@ -350,6 +385,7 @@ More parameters are as follows, |table_model_dir|string|-|Address of table recognition inference model| |table_char_dict_path|string|../../ppocr/utils/dict/table_structure_dict.txt|dictionary file| |table_max_len|int|488|The size of the long side of the input image of the table recognition model, the final input image size of the network is(table_max_len,table_max_len)| +|merge_no_span_structure|bool|true|Whether to merge and to
| Methods | R | P | F | FPS |
| SegLink [26] | 70.0 | 86.0 | 77.0 | 8.9 |
| PixelLink [4] | 73.2 | 83.0 | 77.8 | - |
| TextSnake [18] | 73.9 | 83.2 | 78.3 | 1.1 |
| TextField [37] | 75.9 | 87.4 | 81.3 | 5.2 |
| MSR[38] | 76.7 | 87.4 | 81.7 | - |
| FTSN [3] | 77.1 | 87.6 | 82.0 | - |
| LSE[30] | 81.7 | 84.2 | 82.9 | - |
| CRAFT [2] | 78.2 | 88.2 | 82.9 | 8.6 |
| MCN [16] | 79 | 88 | 83 | - |
| ATRR[35] | 82.1 | 85.2 | 83.6 | - |
| PAN [34] | 83.8 | 84.4 | 84.1 | 30.2 |
| DB[12] | 79.2 | 91.5 | 84.9 | 32.0 |
| DRRG [41] | 82.30 | 88.05 | 85.08 | - |
| Ours (SynText) | 80.68 | 85.40 | 82.97 | 12.68 |
| Ours (MLT-17) | 84.54 | 86.62 | 85.57 | 12.31 |
+predict img: ../../ppstructure/docs/table/1.png +0 type: text, region: [12,729,410,848], score: 0.781044, res: count of ocr result is : 7 +********** print ocr result ********** +0 det boxes: [[4,1],[79,1],[79,12],[4,12]] rec text: CTW1500. rec score: 0.769472 +... +6 det boxes: [[4,99],[391,99],[391,112],[4,112]] rec text: sate-of-the-artmethods[12.34.36l.ourapproachachieves rec score: 0.90414 +********** end print ocr result ********** +1 type: text, region: [69,342,342,359], score: 0.703666, res: count of ocr result is : 1 +********** print ocr result ********** +0 det boxes: [[8,2],[269,2],[269,13],[8,13]] rec text: Table6.Experimentalresults on CTW-1500 rec score: 0.890454 +********** end print ocr result ********** +2 type: text, region: [70,316,706,332], score: 0.659738, res: count of ocr result is : 2 +********** print ocr result ********** +0 det boxes: [[373,2],[630,2],[630,11],[373,11]] rec text: oroposals.andthegreencontoursarefinal rec score: 0.919729 +1 det boxes: [[8,3],[357,3],[357,11],[8,11]] rec text: Visualexperimentalresultshebluecontoursareboundar rec score: 0.915963 +********** end print ocr result ********** +3 type: text, region: [489,342,789,359], score: 0.630538, res: count of ocr result is : 1 +********** print ocr result ********** +0 det boxes: [[8,2],[294,2],[294,14],[8,14]] rec text: Table7.Experimentalresults onMSRA-TD500 rec score: 0.942251 +********** end print ocr result ********** +4 type: text, region: [444,751,841,848], score: 0.607345, res: count of ocr result is : 5 +********** print ocr result ********** +0 det boxes: [[19,3],[389,3],[389,17],[19,17]] rec text: Inthispaper,weproposeanovel adaptivebound rec score: 0.941031 +1 det boxes: [[4,22],[390,22],[390,36],[4,36]] rec text: aryproposalnetworkforarbitraryshapetextdetection rec score: 0.960172 +2 det boxes: [[4,42],[392,42],[392,56],[4,56]] rec text: whichadoptanboundaryproposalmodeltogeneratecoarse rec score: 0.934647 +3 det boxes: [[4,61],[389,61],[389,75],[4,75]] rec text: ooundaryproposals,andthenadoptanadaptiveboundary rec score: 0.946296 +4 det boxes: [[5,80],[387,80],[387,93],[5,93]] rec text: leformationmodelcombinedwithGCNandRNNtoper rec score: 0.952401 +********** end print ocr result ********** +5 type: title, region: [444,705,564,724], score: 0.785429, res: count of ocr result is : 1 +********** print ocr result ********** +0 det boxes: [[6,2],[113,2],[113,14],[6,14]] rec text: 5.Conclusion rec score: 0.856903 +********** end print ocr result ********** +6 type: table, region: [14,360,402,711], score: 0.963643, res:
| Methods | Ext | R | P | F | FPS |
| TextSnake [18] | Syn | 85.3 | 67.9 | 75.6 | |
| CSE [17] | MiLT | 76.1 | 78.7 | 77.4 | 0.38 |
| LOMO[40] | Syn | 76.5 | 85.7 | 80.8 | 4.4 |
| ATRR[35] | Sy- | 80.2 | 80.1 | 80.1 | - |
| SegLink++ [28] | Syn | 79.8 | 82.8 | 81.3 | - |
| TextField [37] | Syn | 79.8 | 83.0 | 81.4 | 6.0 |
| MSR[38] | Syn | 79.0 | 84.1 | 81.5 | 4.3 |
| PSENet-1s [33] | MLT | 79.7 | 84.8 | 82.2 | 3.9 |
| DB [12] | Syn | 80.2 | 86.9 | 83.4 | 22.0 |
| CRAFT [2] | Syn | 81.1 | 86.0 | 83.5 | - |
| TextDragon [5] | MLT+ | 82.8 | 84.5 | 83.6 | |
| PAN [34] | Syn | 81.2 | 86.4 | 83.7 | 39.8 |
| ContourNet [36] | | 84.1 | 83.7 | 83.9 | 4.5 |
| DRRG [41] | MLT | 83.02 | 85.93 | 84.45 | - |
| TextPerception[23] | Syn | 81.9 | 87.5 | 84.6 | |
| Ours | Syn | 80.57 | 87.66 | 83.97 | 12.08 |
| Ours | | 81.45 | 87.81 | 84.51 | 12.15 |
| Ours | MLT | 83.60 | 86.45 | 85.00 | 12.21 |
+The table visualized image saved in ./output//6_1.png +7 type: table, region: [462,359,820,657], score: 0.953917, res:
| Methods | R | P | F | FPS |
| SegLink [26] | 70.0 | 86.0 | 77.0 | 8.9 |
| PixelLink [4] | 73.2 | 83.0 | 77.8 | - |
| TextSnake [18] | 73.9 | 83.2 | 78.3 | 1.1 |
| TextField [37] | 75.9 | 87.4 | 81.3 | 5.2 |
| MSR[38] | 76.7 | 87.4 | 81.7 | - |
| FTSN[3] | 77.1 | 87.6 | 82.0 | : |
| LSE[30] | 81.7 | 84.2 | 82.9 | |
| CRAFT [2] | 78.2 | 88.2 | 82.9 | 8.6 |
| MCN [16] | 79 | 88 | 83 | - |
| ATRR[35] | 82.1 | 85.2 | 83.6 | - |
| PAN [34] | 83.8 | 84.4 | 84.1 | 30.2 |
| DB[12] | 79.2 | 91.5 | 84.9 | 32.0 |
| DRRG [41] | 82.30 | 88.05 | 85.08 | - |
| Ours (SynText) | 80.68 | 85.40 | 82.97 | 12.68 |
| Ours (MLT-17) | 84.54 | 86.62 | 85.57 | 12.31 |
+The table visualized image saved in ./output//7_1.png +8 type: figure, region: [14,3,836,310], score: 0.969443, res: count of ocr result is : 26 +********** print ocr result ********** +0 det boxes: [[506,14],[539,15],[539,22],[506,21]] rec text: E rec score: 0.318073 +... +25 det boxes: [[680,290],[759,288],[759,303],[680,305]] rec text: (d) CTW1500 rec score: 0.95911 +********** end print ocr result ********** ``` diff --git a/deploy/cpp_infer/readme_ch.md b/deploy/cpp_infer/readme_ch.md index d94c95c8c5a5bf02d5b9fb4f16edd4da8ebe1e3f..444567f193abade94029d0f048675eaf1cf03690 100644 --- a/deploy/cpp_infer/readme_ch.md +++ b/deploy/cpp_infer/readme_ch.md @@ -184,6 +184,9 @@ inference/ |-- table | |--inference.pdiparams | |--inference.pdmodel +|-- layout +| |--inference.pdiparams +| |--inference.pdmodel ``` @@ -288,7 +291,30 @@ CUDNN_LIB_DIR=/your_cudnn_lib_dir --cls=true \ ``` -##### 7. 表格识别 +##### 7. 版面分析+表格识别 +```shell +./build/ppocr --det_model_dir=inference/det_db \ + --rec_model_dir=inference/rec_rcnn \ + --table_model_dir=inference/table \ + --image_dir=../../ppstructure/docs/table/table.jpg \ + --layout_model_dir=inference/layout \ + --type=structure \ + --table=true \ + --layout=true +``` + +##### 8. 版面分析 +```shell +./build/ppocr --layout_model_dir=inference/layout \ + --image_dir=../../ppstructure/docs/table/1.png \ + --type=structure \ + --table=false \ + --layout=true \ + --det=false \ + --rec=false +``` + +##### 9. 表格识别 ```shell ./build/ppocr --det_model_dir=inference/det_db \ --rec_model_dir=inference/rec_rcnn \ @@ -352,13 +378,24 @@ CUDNN_LIB_DIR=/your_cudnn_lib_dir |rec_img_w|int|320|文字识别模型输入图像宽度| +- 版面分析模型相关 + +|参数名称|类型|默认参数|意义| +| :---: | :---: | :---: | :---: | +|layout_model_dir|string|-|版面分析模型inference model地址| +|layout_dict_path|string|../../ppocr/utils/dict/layout_dict/layout_publaynet_dict.txt|字典文件| +|layout_score_threshold|float|0.5|检测框的分数阈值| +|layout_nms_threshold|float|0.5|nms的阈值| + + - 表格识别模型相关 |参数名称|类型|默认参数|意义| | :---: | :---: | :---: | :---: | |table_model_dir|string|-|表格识别模型inference model地址| -|table_char_dict_path|string|../../ppocr/utils/dict/table_structure_dict.txt|字典文件| +|table_char_dict_path|string|../../ppocr/utils/dict/table_structure_dict_ch.txt|字典文件| |table_max_len|int|488|表格识别模型输入图像长边大小,最终网络输入图像大小为(table_max_len,table_max_len)| +|merge_no_span_structure|bool|true|是否合并 和 为| * PaddleOCR也支持多语言的预测,更多支持的语言和模型可以参考[识别文档](../../doc/doc_ch/recognition.md)中的多语言字典与模型部分,如果希望进行多语言预测,只需将修改`rec_char_dict_path`(字典文件路径)以及`rec_model_dir`(inference模型路径)字段即可。 @@ -377,11 +414,51 @@ predict img: ../../doc/imgs/12.jpg The detection visualized image saved in ./output//12.jpg ``` -- table +- layout+table ```bash -predict img: ../../ppstructure/docs/table/table.jpg -0 type: table, region: [0,0,371,293], res:
| Methods | R | P | F | FPS |
| SegLink [26] | 70.0 | 86.0 | 77.0 | 8.9 |
| PixelLink [4] | 73.2 | 83.0 | 77.8 | - |
| TextSnake [18] | 73.9 | 83.2 | 78.3 | 1.1 |
| TextField [37] | 75.9 | 87.4 | 81.3 | 5.2 |
| MSR[38] | 76.7 | 87.4 | 81.7 | - |
| FTSN [3] | 77.1 | 87.6 | 82.0 | - |
| LSE[30] | 81.7 | 84.2 | 82.9 | - |
| CRAFT [2] | 78.2 | 88.2 | 82.9 | 8.6 |
| MCN [16] | 79 | 88 | 83 | - |
| ATRR[35] | 82.1 | 85.2 | 83.6 | - |
| PAN [34] | 83.8 | 84.4 | 84.1 | 30.2 |
| DB[12] | 79.2 | 91.5 | 84.9 | 32.0 |
| DRRG [41] | 82.30 | 88.05 | 85.08 | - |
| Ours (SynText) | 80.68 | 85.40 | 82.97 | 12.68 |
| Ours (MLT-17) | 84.54 | 86.62 | 85.57 | 12.31 |
+predict img: ../../ppstructure/docs/table/1.png +0 type: text, region: [12,729,410,848], score: 0.781044, res: count of ocr result is : 7 +********** print ocr result ********** +0 det boxes: [[4,1],[79,1],[79,12],[4,12]] rec text: CTW1500. rec score: 0.769472 +... +6 det boxes: [[4,99],[391,99],[391,112],[4,112]] rec text: sate-of-the-artmethods[12.34.36l.ourapproachachieves rec score: 0.90414 +********** end print ocr result ********** +1 type: text, region: [69,342,342,359], score: 0.703666, res: count of ocr result is : 1 +********** print ocr result ********** +0 det boxes: [[8,2],[269,2],[269,13],[8,13]] rec text: Table6.Experimentalresults on CTW-1500 rec score: 0.890454 +********** end print ocr result ********** +2 type: text, region: [70,316,706,332], score: 0.659738, res: count of ocr result is : 2 +********** print ocr result ********** +0 det boxes: [[373,2],[630,2],[630,11],[373,11]] rec text: oroposals.andthegreencontoursarefinal rec score: 0.919729 +1 det boxes: [[8,3],[357,3],[357,11],[8,11]] rec text: Visualexperimentalresultshebluecontoursareboundar rec score: 0.915963 +********** end print ocr result ********** +3 type: text, region: [489,342,789,359], score: 0.630538, res: count of ocr result is : 1 +********** print ocr result ********** +0 det boxes: [[8,2],[294,2],[294,14],[8,14]] rec text: Table7.Experimentalresults onMSRA-TD500 rec score: 0.942251 +********** end print ocr result ********** +4 type: text, region: [444,751,841,848], score: 0.607345, res: count of ocr result is : 5 +********** print ocr result ********** +0 det boxes: [[19,3],[389,3],[389,17],[19,17]] rec text: Inthispaper,weproposeanovel adaptivebound rec score: 0.941031 +1 det boxes: [[4,22],[390,22],[390,36],[4,36]] rec text: aryproposalnetworkforarbitraryshapetextdetection rec score: 0.960172 +2 det boxes: [[4,42],[392,42],[392,56],[4,56]] rec text: whichadoptanboundaryproposalmodeltogeneratecoarse rec score: 0.934647 +3 det boxes: [[4,61],[389,61],[389,75],[4,75]] rec text: ooundaryproposals,andthenadoptanadaptiveboundary rec score: 0.946296 +4 det boxes: [[5,80],[387,80],[387,93],[5,93]] rec text: leformationmodelcombinedwithGCNandRNNtoper rec score: 0.952401 +********** end print ocr result ********** +5 type: title, region: [444,705,564,724], score: 0.785429, res: count of ocr result is : 1 +********** print ocr result ********** +0 det boxes: [[6,2],[113,2],[113,14],[6,14]] rec text: 5.Conclusion rec score: 0.856903 +********** end print ocr result ********** +6 type: table, region: [14,360,402,711], score: 0.963643, res:
| Methods | Ext | R | P | F | FPS |
| TextSnake [18] | Syn | 85.3 | 67.9 | 75.6 | |
| CSE [17] | MiLT | 76.1 | 78.7 | 77.4 | 0.38 |
| LOMO[40] | Syn | 76.5 | 85.7 | 80.8 | 4.4 |
| ATRR[35] | Sy- | 80.2 | 80.1 | 80.1 | - |
| SegLink++ [28] | Syn | 79.8 | 82.8 | 81.3 | - |
| TextField [37] | Syn | 79.8 | 83.0 | 81.4 | 6.0 |
| MSR[38] | Syn | 79.0 | 84.1 | 81.5 | 4.3 |
| PSENet-1s [33] | MLT | 79.7 | 84.8 | 82.2 | 3.9 |
| DB [12] | Syn | 80.2 | 86.9 | 83.4 | 22.0 |
| CRAFT [2] | Syn | 81.1 | 86.0 | 83.5 | - |
| TextDragon [5] | MLT+ | 82.8 | 84.5 | 83.6 | |
| PAN [34] | Syn | 81.2 | 86.4 | 83.7 | 39.8 |
| ContourNet [36] | | 84.1 | 83.7 | 83.9 | 4.5 |
| DRRG [41] | MLT | 83.02 | 85.93 | 84.45 | - |
| TextPerception[23] | Syn | 81.9 | 87.5 | 84.6 | |
| Ours | Syn | 80.57 | 87.66 | 83.97 | 12.08 |
| Ours | | 81.45 | 87.81 | 84.51 | 12.15 |
| Ours | MLT | 83.60 | 86.45 | 85.00 | 12.21 |
+The table visualized image saved in ./output//6_1.png +7 type: table, region: [462,359,820,657], score: 0.953917, res:
| Methods | R | P | F | FPS |
| SegLink [26] | 70.0 | 86.0 | 77.0 | 8.9 |
| PixelLink [4] | 73.2 | 83.0 | 77.8 | - |
| TextSnake [18] | 73.9 | 83.2 | 78.3 | 1.1 |
| TextField [37] | 75.9 | 87.4 | 81.3 | 5.2 |
| MSR[38] | 76.7 | 87.4 | 81.7 | - |
| FTSN[3] | 77.1 | 87.6 | 82.0 | : |
| LSE[30] | 81.7 | 84.2 | 82.9 | |
| CRAFT [2] | 78.2 | 88.2 | 82.9 | 8.6 |
| MCN [16] | 79 | 88 | 83 | - |
| ATRR[35] | 82.1 | 85.2 | 83.6 | - |
| PAN [34] | 83.8 | 84.4 | 84.1 | 30.2 |
| DB[12] | 79.2 | 91.5 | 84.9 | 32.0 |
| DRRG [41] | 82.30 | 88.05 | 85.08 | - |
| Ours (SynText) | 80.68 | 85.40 | 82.97 | 12.68 |
| Ours (MLT-17) | 84.54 | 86.62 | 85.57 | 12.31 |
+The table visualized image saved in ./output//7_1.png +8 type: figure, region: [14,3,836,310], score: 0.969443, res: count of ocr result is : 26 +********** print ocr result ********** +0 det boxes: [[506,14],[539,15],[539,22],[506,21]] rec text: E rec score: 0.318073 +... +25 det boxes: [[680,290],[759,288],[759,303],[680,305]] rec text: (d) CTW1500 rec score: 0.95911 +********** end print ocr result ********** ``` diff --git a/deploy/cpp_infer/src/args.cpp b/deploy/cpp_infer/src/args.cpp index df1b9e32a3aacc309d6485114f9b267001f79920..28066f0b20061059f32e2658fa4ea70fd827acb7 100644 --- a/deploy/cpp_infer/src/args.cpp +++ b/deploy/cpp_infer/src/args.cpp @@ -51,16 +51,26 @@ DEFINE_string(rec_char_dict_path, "../../ppocr/utils/ppocr_keys_v1.txt", DEFINE_int32(rec_img_h, 48, "rec image height"); DEFINE_int32(rec_img_w, 320, "rec image width"); +// layout model related +DEFINE_string(layout_model_dir, "", "Path of table layout inference model."); +DEFINE_string(layout_dict_path, + "../../ppocr/utils/dict/layout_dict/layout_publaynet_dict.txt", + "Path of dictionary."); +DEFINE_double(layout_score_threshold, 0.5, "Threshold of score."); +DEFINE_double(layout_nms_threshold, 0.5, "Threshold of nms."); // structure model related DEFINE_string(table_model_dir, "", "Path of table struture inference model."); DEFINE_int32(table_max_len, 488, "max len size of input image."); DEFINE_int32(table_batch_num, 1, "table_batch_num."); +DEFINE_bool(merge_no_span_structure, true, + "Whether merge and to "); DEFINE_string(table_char_dict_path, - "../../ppocr/utils/dict/table_structure_dict.txt", + "../../ppocr/utils/dict/table_structure_dict_ch.txt", "Path of dictionary."); // ocr forward related DEFINE_bool(det, true, "Whether use det in forward."); DEFINE_bool(rec, true, "Whether use rec in forward."); DEFINE_bool(cls, false, "Whether use cls in forward."); -DEFINE_bool(table, false, "Whether use table structure in forward."); \ No newline at end of file +DEFINE_bool(table, false, "Whether use table structure in forward."); +DEFINE_bool(layout, false, "Whether use layout analysis in forward."); \ No newline at end of file diff --git a/deploy/cpp_infer/src/main.cpp b/deploy/cpp_infer/src/main.cpp index 66412a7b283f84107e117cfd59fb7d7aabff651c..0c155dd0eca04874d23c3be7e6eff241b73f5f1b 100644 --- a/deploy/cpp_infer/src/main.cpp +++ b/deploy/cpp_infer/src/main.cpp @@ -65,9 +65,18 @@ void check_params() { exit(1); } } + if (FLAGS_layout) { + if (FLAGS_layout_model_dir.empty() || FLAGS_image_dir.empty()) { + std::cout << "Usage[layout]: ./ppocr " + << "--layout_model_dir=/PATH/TO/LAYOUT_INFERENCE_MODEL/ " + << "--image_dir=/PATH/TO/INPUT/IMAGE/" << std::endl; + exit(1); + } + } if (FLAGS_precision != "fp32" && FLAGS_precision != "fp16" && FLAGS_precision != "int8") { - cout << "precison should be 'fp32'(default), 'fp16' or 'int8'. " << endl; + std::cout << "precison should be 'fp32'(default), 'fp16' or 'int8'. 
" + << std::endl; exit(1); } } @@ -75,65 +84,94 @@ void check_params() { void ocr(std::vector &cv_all_img_names) { PPOCR ocr = PPOCR(); - std::vector> ocr_results = - ocr.ocr(cv_all_img_names, FLAGS_det, FLAGS_rec, FLAGS_cls); + if (FLAGS_benchmark) { + ocr.reset_timer(); + } + std::vector img_list; + std::vector img_names; for (int i = 0; i < cv_all_img_names.size(); ++i) { - if (FLAGS_benchmark) { - cout << cv_all_img_names[i] << '\t'; - if (FLAGS_rec && FLAGS_det) { - Utility::print_result(ocr_results[i]); - } else if (FLAGS_det) { - for (int n = 0; n < ocr_results[i].size(); n++) { - for (int m = 0; m < ocr_results[i][n].box.size(); m++) { - cout << ocr_results[i][n].box[m][0] << ' ' - << ocr_results[i][n].box[m][1] << ' '; - } - } - cout << endl; - } else { - Utility::print_result(ocr_results[i]); - } - } else { - cout << cv_all_img_names[i] << "\n"; - Utility::print_result(ocr_results[i]); - if (FLAGS_visualize && FLAGS_det) { - cv::Mat srcimg = cv::imread(cv_all_img_names[i], cv::IMREAD_COLOR); - if (!srcimg.data) { - std::cerr << "[ERROR] image read failed! image path: " - << cv_all_img_names[i] << endl; - exit(1); - } - std::string file_name = Utility::basename(cv_all_img_names[i]); + cv::Mat img = cv::imread(cv_all_img_names[i], cv::IMREAD_COLOR); + if (!img.data) { + std::cerr << "[ERROR] image read failed! image path: " + << cv_all_img_names[i] << std::endl; + continue; + } + img_list.push_back(img); + img_names.push_back(cv_all_img_names[i]); + } - Utility::VisualizeBboxes(srcimg, ocr_results[i], - FLAGS_output + "/" + file_name); - } - cout << "***************************" << endl; + std::vector> ocr_results = + ocr.ocr(img_list, FLAGS_det, FLAGS_rec, FLAGS_cls); + + for (int i = 0; i < img_names.size(); ++i) { + std::cout << "predict img: " << cv_all_img_names[i] << std::endl; + Utility::print_result(ocr_results[i]); + if (FLAGS_visualize && FLAGS_det) { + std::string file_name = Utility::basename(img_names[i]); + cv::Mat srcimg = img_list[i]; + Utility::VisualizeBboxes(srcimg, ocr_results[i], + FLAGS_output + "/" + file_name); } } + if (FLAGS_benchmark) { + ocr.benchmark_log(cv_all_img_names.size()); + } } void structure(std::vector &cv_all_img_names) { PaddleOCR::PaddleStructure engine = PaddleOCR::PaddleStructure(); - std::vector> structure_results = - engine.structure(cv_all_img_names, false, FLAGS_table); + + if (FLAGS_benchmark) { + engine.reset_timer(); + } + for (int i = 0; i < cv_all_img_names.size(); i++) { - cout << "predict img: " << cv_all_img_names[i] << endl; - for (int j = 0; j < structure_results[i].size(); j++) { - std::cout << j << "\ttype: " << structure_results[i][j].type + std::cout << "predict img: " << cv_all_img_names[i] << std::endl; + cv::Mat img = cv::imread(cv_all_img_names[i], cv::IMREAD_COLOR); + if (!img.data) { + std::cerr << "[ERROR] image read failed! 
image path: " + << cv_all_img_names[i] << std::endl; + continue; + } + + std::vector structure_results = engine.structure( + img, FLAGS_layout, FLAGS_table, FLAGS_det && FLAGS_rec); + + for (int j = 0; j < structure_results.size(); j++) { + std::cout << j << "\ttype: " << structure_results[j].type << ", region: ["; - std::cout << structure_results[i][j].box[0] << "," - << structure_results[i][j].box[1] << "," - << structure_results[i][j].box[2] << "," - << structure_results[i][j].box[3] << "], res: "; - if (structure_results[i][j].type == "table") { - std::cout << structure_results[i][j].html << std::endl; + std::cout << structure_results[j].box[0] << "," + << structure_results[j].box[1] << "," + << structure_results[j].box[2] << "," + << structure_results[j].box[3] << "], score: "; + std::cout << structure_results[j].confidence << ", res: "; + + if (structure_results[j].type == "table") { + std::cout << structure_results[j].html << std::endl; + if (structure_results[j].cell_box.size() > 0 && FLAGS_visualize) { + std::string file_name = Utility::basename(cv_all_img_names[i]); + + Utility::VisualizeBboxes(img, structure_results[j], + FLAGS_output + "/" + std::to_string(j) + + "_" + file_name); + } } else { - Utility::print_result(structure_results[i][j].text_res); + std::cout << "count of ocr result is : " + << structure_results[j].text_res.size() << std::endl; + if (structure_results[j].text_res.size() > 0) { + std::cout << "********** print ocr result " + << "**********" << std::endl; + Utility::print_result(structure_results[j].text_res); + std::cout << "********** end print ocr result " + << "**********" << std::endl; + } } } } + if (FLAGS_benchmark) { + engine.benchmark_log(cv_all_img_names.size()); + } } int main(int argc, char **argv) { @@ -143,19 +181,22 @@ int main(int argc, char **argv) { if (!Utility::PathExists(FLAGS_image_dir)) { std::cerr << "[ERROR] image path not exist! 
image_dir: " << FLAGS_image_dir - << endl; + << std::endl; exit(1); } std::vector cv_all_img_names; cv::glob(FLAGS_image_dir, cv_all_img_names); - std::cout << "total images num: " << cv_all_img_names.size() << endl; + std::cout << "total images num: " << cv_all_img_names.size() << std::endl; + if (!Utility::PathExists(FLAGS_output)) { + Utility::CreateDir(FLAGS_output); + } if (FLAGS_type == "ocr") { ocr(cv_all_img_names); } else if (FLAGS_type == "structure") { structure(cv_all_img_names); } else { - std::cout << "only value in ['ocr','structure'] is supported" << endl; + std::cout << "only value in ['ocr','structure'] is supported" << std::endl; } } diff --git a/deploy/cpp_infer/src/ocr_cls.cpp b/deploy/cpp_infer/src/ocr_cls.cpp index 674630bf1e7e04841e027a7320d62af4a453ffc8..abcfed125f45253fc13c72f94621dda25ba12780 100644 --- a/deploy/cpp_infer/src/ocr_cls.cpp +++ b/deploy/cpp_infer/src/ocr_cls.cpp @@ -32,7 +32,7 @@ void Classifier::Run(std::vector img_list, for (int beg_img_no = 0; beg_img_no < img_num; beg_img_no += this->cls_batch_num_) { auto preprocess_start = std::chrono::steady_clock::now(); - int end_img_no = min(img_num, beg_img_no + this->cls_batch_num_); + int end_img_no = std::min(img_num, beg_img_no + this->cls_batch_num_); int batch_num = end_img_no - beg_img_no; // preprocess std::vector norm_img_batch; @@ -97,7 +97,7 @@ void Classifier::Run(std::vector img_list, } void Classifier::LoadModel(const std::string &model_dir) { - AnalysisConfig config; + paddle_infer::Config config; config.SetModel(model_dir + "/inference.pdmodel", model_dir + "/inference.pdiparams"); @@ -112,6 +112,11 @@ void Classifier::LoadModel(const std::string &model_dir) { precision = paddle_infer::Config::Precision::kInt8; } config.EnableTensorRtEngine(1 << 20, 10, 3, precision, false, false); + if (!Utility::PathExists("./trt_cls_shape.txt")) { + config.CollectShapeRangeInfo("./trt_cls_shape.txt"); + } else { + config.EnableTunedTensorRtDynamicShape("./trt_cls_shape.txt", true); + } } } else { config.DisableGpu(); @@ -131,6 +136,6 @@ void Classifier::LoadModel(const std::string &model_dir) { config.EnableMemoryOptim(); config.DisableGlogInfo(); - this->predictor_ = CreatePredictor(config); + this->predictor_ = paddle_infer::CreatePredictor(config); } } // namespace PaddleOCR diff --git a/deploy/cpp_infer/src/ocr_det.cpp b/deploy/cpp_infer/src/ocr_det.cpp index 56de195186a0d4d6c8b2482eb57c106347485928..74fa09bed1193a89091dca82569fa256d1773433 100644 --- a/deploy/cpp_infer/src/ocr_det.cpp +++ b/deploy/cpp_infer/src/ocr_det.cpp @@ -32,49 +32,12 @@ void DBDetector::LoadModel(const std::string &model_dir) { if (this->precision_ == "int8") { precision = paddle_infer::Config::Precision::kInt8; } - config.EnableTensorRtEngine(1 << 20, 1, 20, precision, false, false); - std::map> min_input_shape = { - {"x", {1, 3, 50, 50}}, - {"conv2d_92.tmp_0", {1, 120, 20, 20}}, - {"conv2d_91.tmp_0", {1, 24, 10, 10}}, - {"conv2d_59.tmp_0", {1, 96, 20, 20}}, - {"nearest_interp_v2_1.tmp_0", {1, 256, 10, 10}}, - {"nearest_interp_v2_2.tmp_0", {1, 256, 20, 20}}, - {"conv2d_124.tmp_0", {1, 256, 20, 20}}, - {"nearest_interp_v2_3.tmp_0", {1, 64, 20, 20}}, - {"nearest_interp_v2_4.tmp_0", {1, 64, 20, 20}}, - {"nearest_interp_v2_5.tmp_0", {1, 64, 20, 20}}, - {"elementwise_add_7", {1, 56, 2, 2}}, - {"nearest_interp_v2_0.tmp_0", {1, 256, 2, 2}}}; - std::map> max_input_shape = { - {"x", {1, 3, 1536, 1536}}, - {"conv2d_92.tmp_0", {1, 120, 400, 400}}, - {"conv2d_91.tmp_0", {1, 24, 200, 200}}, - {"conv2d_59.tmp_0", {1, 96, 400, 400}}, - 
{"nearest_interp_v2_1.tmp_0", {1, 256, 200, 200}}, - {"nearest_interp_v2_2.tmp_0", {1, 256, 400, 400}}, - {"conv2d_124.tmp_0", {1, 256, 400, 400}}, - {"nearest_interp_v2_3.tmp_0", {1, 64, 400, 400}}, - {"nearest_interp_v2_4.tmp_0", {1, 64, 400, 400}}, - {"nearest_interp_v2_5.tmp_0", {1, 64, 400, 400}}, - {"elementwise_add_7", {1, 56, 400, 400}}, - {"nearest_interp_v2_0.tmp_0", {1, 256, 400, 400}}}; - std::map> opt_input_shape = { - {"x", {1, 3, 640, 640}}, - {"conv2d_92.tmp_0", {1, 120, 160, 160}}, - {"conv2d_91.tmp_0", {1, 24, 80, 80}}, - {"conv2d_59.tmp_0", {1, 96, 160, 160}}, - {"nearest_interp_v2_1.tmp_0", {1, 256, 80, 80}}, - {"nearest_interp_v2_2.tmp_0", {1, 256, 160, 160}}, - {"conv2d_124.tmp_0", {1, 256, 160, 160}}, - {"nearest_interp_v2_3.tmp_0", {1, 64, 160, 160}}, - {"nearest_interp_v2_4.tmp_0", {1, 64, 160, 160}}, - {"nearest_interp_v2_5.tmp_0", {1, 64, 160, 160}}, - {"elementwise_add_7", {1, 56, 40, 40}}, - {"nearest_interp_v2_0.tmp_0", {1, 256, 40, 40}}}; - - config.SetTRTDynamicShapeInfo(min_input_shape, max_input_shape, - opt_input_shape); + config.EnableTensorRtEngine(1 << 30, 1, 20, precision, false, false); + if (!Utility::PathExists("./trt_det_shape.txt")) { + config.CollectShapeRangeInfo("./trt_det_shape.txt"); + } else { + config.EnableTunedTensorRtDynamicShape("./trt_det_shape.txt", true); + } } } else { config.DisableGpu(); @@ -95,7 +58,7 @@ void DBDetector::LoadModel(const std::string &model_dir) { config.EnableMemoryOptim(); // config.DisableGlogInfo(); - this->predictor_ = CreatePredictor(config); + this->predictor_ = paddle_infer::CreatePredictor(config); } void DBDetector::Run(cv::Mat &img, diff --git a/deploy/cpp_infer/src/ocr_rec.cpp b/deploy/cpp_infer/src/ocr_rec.cpp index 0f90ddfab4872f97829da081e64cb7437e72493a..96715163681092c0075fdbf456cc38b1679d82b9 100644 --- a/deploy/cpp_infer/src/ocr_rec.cpp +++ b/deploy/cpp_infer/src/ocr_rec.cpp @@ -37,7 +37,7 @@ void CRNNRecognizer::Run(std::vector img_list, for (int beg_img_no = 0; beg_img_no < img_num; beg_img_no += this->rec_batch_num_) { auto preprocess_start = std::chrono::steady_clock::now(); - int end_img_no = min(img_num, beg_img_no + this->rec_batch_num_); + int end_img_no = std::min(img_num, beg_img_no + this->rec_batch_num_); int batch_num = end_img_no - beg_img_no; int imgH = this->rec_image_shape_[1]; int imgW = this->rec_image_shape_[2]; @@ -46,7 +46,7 @@ void CRNNRecognizer::Run(std::vector img_list, int h = img_list[indices[ino]].rows; int w = img_list[indices[ino]].cols; float wh_ratio = w * 1.0 / h; - max_wh_ratio = max(max_wh_ratio, wh_ratio); + max_wh_ratio = std::max(max_wh_ratio, wh_ratio); } int batch_width = imgW; @@ -60,7 +60,7 @@ void CRNNRecognizer::Run(std::vector img_list, this->normalize_op_.Run(&resize_img, this->mean_, this->scale_, this->is_scale_); norm_img_batch.push_back(resize_img); - batch_width = max(resize_img.cols, batch_width); + batch_width = std::max(resize_img.cols, batch_width); } std::vector input(batch_num * 3 * imgH * batch_width, 0.0f); @@ -115,7 +115,7 @@ void CRNNRecognizer::Run(std::vector img_list, last_index = argmax_idx; } score /= count; - if (isnan(score)) { + if (std::isnan(score)) { continue; } rec_texts[indices[beg_img_no + m]] = str_res; @@ -130,7 +130,6 @@ void CRNNRecognizer::Run(std::vector img_list, } void CRNNRecognizer::LoadModel(const std::string &model_dir) { - // AnalysisConfig config; paddle_infer::Config config; config.SetModel(model_dir + "/inference.pdmodel", model_dir + "/inference.pdiparams"); @@ -147,20 +146,11 @@ void 
CRNNRecognizer::LoadModel(const std::string &model_dir) { if (this->precision_ == "int8") { precision = paddle_infer::Config::Precision::kInt8; } - config.EnableTensorRtEngine(1 << 20, 10, 15, precision, false, false); - int imgH = this->rec_image_shape_[1]; - int imgW = this->rec_image_shape_[2]; - std::map> min_input_shape = { - {"x", {1, 3, imgH, 10}}, {"lstm_0.tmp_0", {10, 1, 96}}}; - std::map> max_input_shape = { - {"x", {this->rec_batch_num_, 3, imgH, 2500}}, - {"lstm_0.tmp_0", {1000, 1, 96}}}; - std::map> opt_input_shape = { - {"x", {this->rec_batch_num_, 3, imgH, imgW}}, - {"lstm_0.tmp_0", {25, 1, 96}}}; - - config.SetTRTDynamicShapeInfo(min_input_shape, max_input_shape, - opt_input_shape); + if (!Utility::PathExists("./trt_rec_shape.txt")) { + config.CollectShapeRangeInfo("./trt_rec_shape.txt"); + } else { + config.EnableTunedTensorRtDynamicShape("./trt_rec_shape.txt", true); + } } } else { config.DisableGpu(); @@ -185,7 +175,7 @@ void CRNNRecognizer::LoadModel(const std::string &model_dir) { config.EnableMemoryOptim(); // config.DisableGlogInfo(); - this->predictor_ = CreatePredictor(config); + this->predictor_ = paddle_infer::CreatePredictor(config); } } // namespace PaddleOCR diff --git a/deploy/cpp_infer/src/paddleocr.cpp b/deploy/cpp_infer/src/paddleocr.cpp index 1de4fc7e9af8bf63cf68ef42d2a508cdc4b5f9f3..86747c60d682c4f2df66a8bc8f5c9dae68b80170 100644 --- a/deploy/cpp_infer/src/paddleocr.cpp +++ b/deploy/cpp_infer/src/paddleocr.cpp @@ -16,7 +16,7 @@ #include #include "auto_log/autolog.h" -#include + namespace PaddleOCR { PPOCR::PPOCR() { @@ -44,8 +44,71 @@ PPOCR::PPOCR() { } }; -void PPOCR::det(cv::Mat img, std::vector &ocr_results, - std::vector ×) { +std::vector> +PPOCR::ocr(std::vector img_list, bool det, bool rec, bool cls) { + std::vector> ocr_results; + + if (!det) { + std::vector ocr_result; + ocr_result.resize(img_list.size()); + if (cls && this->classifier_ != nullptr) { + this->cls(img_list, ocr_result); + for (int i = 0; i < img_list.size(); i++) { + if (ocr_result[i].cls_label % 2 == 1 && + ocr_result[i].cls_score > this->classifier_->cls_thresh) { + cv::rotate(img_list[i], img_list[i], 1); + } + } + } + if (rec) { + this->rec(img_list, ocr_result); + } + for (int i = 0; i < ocr_result.size(); ++i) { + std::vector ocr_result_tmp; + ocr_result_tmp.push_back(ocr_result[i]); + ocr_results.push_back(ocr_result_tmp); + } + } else { + for (int i = 0; i < img_list.size(); ++i) { + std::vector ocr_result = + this->ocr(img_list[i], true, rec, cls); + ocr_results.push_back(ocr_result); + } + } + return ocr_results; +} + +std::vector PPOCR::ocr(cv::Mat img, bool det, bool rec, + bool cls) { + + std::vector ocr_result; + // det + this->det(img, ocr_result); + // crop image + std::vector img_list; + for (int j = 0; j < ocr_result.size(); j++) { + cv::Mat crop_img; + crop_img = Utility::GetRotateCropImage(img, ocr_result[j].box); + img_list.push_back(crop_img); + } + // cls + if (cls && this->classifier_ != nullptr) { + this->cls(img_list, ocr_result); + for (int i = 0; i < img_list.size(); i++) { + if (ocr_result[i].cls_label % 2 == 1 && + ocr_result[i].cls_score > this->classifier_->cls_thresh) { + cv::rotate(img_list[i], img_list[i], 1); + } + } + } + // rec + if (rec) { + this->rec(img_list, ocr_result); + } + return ocr_result; +} + +void PPOCR::det(cv::Mat img, std::vector &ocr_results) { std::vector>> boxes; std::vector det_times; @@ -58,14 +121,13 @@ void PPOCR::det(cv::Mat img, std::vector &ocr_results, } // sort boex from top to bottom, from left to right 
Utility::sorted_boxes(ocr_results); - times[0] += det_times[0]; - times[1] += det_times[1]; - times[2] += det_times[2]; + this->time_info_det[0] += det_times[0]; + this->time_info_det[1] += det_times[1]; + this->time_info_det[2] += det_times[2]; } void PPOCR::rec(std::vector img_list, - std::vector &ocr_results, - std::vector ×) { + std::vector &ocr_results) { std::vector rec_texts(img_list.size(), ""); std::vector rec_text_scores(img_list.size(), 0); std::vector rec_times; @@ -75,14 +137,13 @@ void PPOCR::rec(std::vector img_list, ocr_results[i].text = rec_texts[i]; ocr_results[i].score = rec_text_scores[i]; } - times[0] += rec_times[0]; - times[1] += rec_times[1]; - times[2] += rec_times[2]; + this->time_info_rec[0] += rec_times[0]; + this->time_info_rec[1] += rec_times[1]; + this->time_info_rec[2] += rec_times[2]; } void PPOCR::cls(std::vector img_list, - std::vector &ocr_results, - std::vector ×) { + std::vector &ocr_results) { std::vector cls_labels(img_list.size(), 0); std::vector cls_scores(img_list.size(), 0); std::vector cls_times; @@ -92,125 +153,43 @@ void PPOCR::cls(std::vector img_list, ocr_results[i].cls_label = cls_labels[i]; ocr_results[i].cls_score = cls_scores[i]; } - times[0] += cls_times[0]; - times[1] += cls_times[1]; - times[2] += cls_times[2]; + this->time_info_cls[0] += cls_times[0]; + this->time_info_cls[1] += cls_times[1]; + this->time_info_cls[2] += cls_times[2]; } -std::vector> -PPOCR::ocr(std::vector cv_all_img_names, bool det, bool rec, - bool cls) { - std::vector time_info_det = {0, 0, 0}; - std::vector time_info_rec = {0, 0, 0}; - std::vector time_info_cls = {0, 0, 0}; - std::vector> ocr_results; - - if (!det) { - std::vector ocr_result; - // read image - std::vector img_list; - for (int i = 0; i < cv_all_img_names.size(); ++i) { - cv::Mat srcimg = cv::imread(cv_all_img_names[i], cv::IMREAD_COLOR); - if (!srcimg.data) { - std::cerr << "[ERROR] image read failed! image path: " - << cv_all_img_names[i] << endl; - exit(1); - } - img_list.push_back(srcimg); - OCRPredictResult res; - ocr_result.push_back(res); - } - if (cls && this->classifier_ != nullptr) { - this->cls(img_list, ocr_result, time_info_cls); - for (int i = 0; i < img_list.size(); i++) { - if (ocr_result[i].cls_label % 2 == 1 && - ocr_result[i].cls_score > this->classifier_->cls_thresh) { - cv::rotate(img_list[i], img_list[i], 1); - } - } - } - if (rec) { - this->rec(img_list, ocr_result, time_info_rec); - } - for (int i = 0; i < cv_all_img_names.size(); ++i) { - std::vector ocr_result_tmp; - ocr_result_tmp.push_back(ocr_result[i]); - ocr_results.push_back(ocr_result_tmp); - } - } else { - if (!Utility::PathExists(FLAGS_output) && FLAGS_det) { - Utility::CreateDir(FLAGS_output); - } - - for (int i = 0; i < cv_all_img_names.size(); ++i) { - std::vector ocr_result; - if (!FLAGS_benchmark) { - cout << "predict img: " << cv_all_img_names[i] << endl; - } - - cv::Mat srcimg = cv::imread(cv_all_img_names[i], cv::IMREAD_COLOR); - if (!srcimg.data) { - std::cerr << "[ERROR] image read failed! 
image path: " - << cv_all_img_names[i] << endl; - exit(1); - } - // det - this->det(srcimg, ocr_result, time_info_det); - // crop image - std::vector img_list; - for (int j = 0; j < ocr_result.size(); j++) { - cv::Mat crop_img; - crop_img = Utility::GetRotateCropImage(srcimg, ocr_result[j].box); - img_list.push_back(crop_img); - } - - // cls - if (cls && this->classifier_ != nullptr) { - this->cls(img_list, ocr_result, time_info_cls); - for (int i = 0; i < img_list.size(); i++) { - if (ocr_result[i].cls_label % 2 == 1 && - ocr_result[i].cls_score > this->classifier_->cls_thresh) { - cv::rotate(img_list[i], img_list[i], 1); - } - } - } - // rec - if (rec) { - this->rec(img_list, ocr_result, time_info_rec); - } - ocr_results.push_back(ocr_result); - } - } - if (FLAGS_benchmark) { - this->log(time_info_det, time_info_rec, time_info_cls, - cv_all_img_names.size()); - } - return ocr_results; -} // namespace PaddleOCR +void PPOCR::reset_timer() { + this->time_info_det = {0, 0, 0}; + this->time_info_rec = {0, 0, 0}; + this->time_info_cls = {0, 0, 0}; +} -void PPOCR::log(std::vector &det_times, std::vector &rec_times, - std::vector &cls_times, int img_num) { - if (det_times[0] + det_times[1] + det_times[2] > 0) { +void PPOCR::benchmark_log(int img_num) { + if (this->time_info_det[0] + this->time_info_det[1] + this->time_info_det[2] > + 0) { AutoLogger autolog_det("ocr_det", FLAGS_use_gpu, FLAGS_use_tensorrt, FLAGS_enable_mkldnn, FLAGS_cpu_threads, 1, "dynamic", - FLAGS_precision, det_times, img_num); + FLAGS_precision, this->time_info_det, img_num); autolog_det.report(); } - if (rec_times[0] + rec_times[1] + rec_times[2] > 0) { + if (this->time_info_rec[0] + this->time_info_rec[1] + this->time_info_rec[2] > + 0) { AutoLogger autolog_rec("ocr_rec", FLAGS_use_gpu, FLAGS_use_tensorrt, FLAGS_enable_mkldnn, FLAGS_cpu_threads, FLAGS_rec_batch_num, "dynamic", FLAGS_precision, - rec_times, img_num); + this->time_info_rec, img_num); autolog_rec.report(); } - if (cls_times[0] + cls_times[1] + cls_times[2] > 0) { + if (this->time_info_cls[0] + this->time_info_cls[1] + this->time_info_cls[2] > + 0) { AutoLogger autolog_cls("ocr_cls", FLAGS_use_gpu, FLAGS_use_tensorrt, FLAGS_enable_mkldnn, FLAGS_cpu_threads, FLAGS_cls_batch_num, "dynamic", FLAGS_precision, - cls_times, img_num); + this->time_info_cls, img_num); autolog_cls.report(); } } + PPOCR::~PPOCR() { if (this->detector_ != nullptr) { delete this->detector_; diff --git a/deploy/cpp_infer/src/paddlestructure.cpp b/deploy/cpp_infer/src/paddlestructure.cpp index 1ca85a96bbcf09472ce5916375a24a9441a2da53..b2e35f8c777bde3cea0a3fefd0ce8517d8d75318 100644 --- a/deploy/cpp_infer/src/paddlestructure.cpp +++ b/deploy/cpp_infer/src/paddlestructure.cpp @@ -16,83 +16,83 @@ #include #include "auto_log/autolog.h" -#include -#include namespace PaddleOCR { PaddleStructure::PaddleStructure() { + if (FLAGS_layout) { + this->layout_model_ = new StructureLayoutRecognizer( + FLAGS_layout_model_dir, FLAGS_use_gpu, FLAGS_gpu_id, FLAGS_gpu_mem, + FLAGS_cpu_threads, FLAGS_enable_mkldnn, FLAGS_layout_dict_path, + FLAGS_use_tensorrt, FLAGS_precision, FLAGS_layout_score_threshold, + FLAGS_layout_nms_threshold); + } if (FLAGS_table) { - this->recognizer_ = new StructureTableRecognizer( + this->table_model_ = new StructureTableRecognizer( FLAGS_table_model_dir, FLAGS_use_gpu, FLAGS_gpu_id, FLAGS_gpu_mem, FLAGS_cpu_threads, FLAGS_enable_mkldnn, FLAGS_table_char_dict_path, FLAGS_use_tensorrt, FLAGS_precision, FLAGS_table_batch_num, - FLAGS_table_max_len); + FLAGS_table_max_len, 
FLAGS_merge_no_span_structure); } }; -std::vector> -PaddleStructure::structure(std::vector cv_all_img_names, - bool layout, bool table) { - std::vector time_info_det = {0, 0, 0}; - std::vector time_info_rec = {0, 0, 0}; - std::vector time_info_cls = {0, 0, 0}; - std::vector time_info_table = {0, 0, 0}; +std::vector +PaddleStructure::structure(cv::Mat srcimg, bool layout, bool table, bool ocr) { + cv::Mat img; + srcimg.copyTo(img); - std::vector> structure_results; + std::vector structure_results; - if (!Utility::PathExists(FLAGS_output) && FLAGS_det) { - mkdir(FLAGS_output.c_str(), 0777); + if (layout) { + this->layout(img, structure_results); + } else { + StructurePredictResult res; + res.type = "table"; + res.box = std::vector(4, 0.0); + res.box[2] = img.cols; + res.box[3] = img.rows; + structure_results.push_back(res); } - for (int i = 0; i < cv_all_img_names.size(); ++i) { - std::vector structure_result; - cv::Mat srcimg = cv::imread(cv_all_img_names[i], cv::IMREAD_COLOR); - if (!srcimg.data) { - std::cerr << "[ERROR] image read failed! image path: " - << cv_all_img_names[i] << endl; - exit(1); - } - if (layout) { - } else { - StructurePredictResult res; - res.type = "table"; - res.box = std::vector(4, 0); - res.box[2] = srcimg.cols; - res.box[3] = srcimg.rows; - structure_result.push_back(res); - } - cv::Mat roi_img; - for (int i = 0; i < structure_result.size(); i++) { - // crop image - roi_img = Utility::crop_image(srcimg, structure_result[i].box); - if (structure_result[i].type == "table") { - this->table(roi_img, structure_result[i], time_info_table, - time_info_det, time_info_rec, time_info_cls); - } + cv::Mat roi_img; + for (int i = 0; i < structure_results.size(); i++) { + // crop image + roi_img = Utility::crop_image(img, structure_results[i].box); + if (structure_results[i].type == "table" && table) { + this->table(roi_img, structure_results[i]); + } else if (ocr) { + structure_results[i].text_res = this->ocr(roi_img, true, true, false); } - structure_results.push_back(structure_result); } + return structure_results; }; +void PaddleStructure::layout( + cv::Mat img, std::vector &structure_result) { + std::vector layout_times; + this->layout_model_->Run(img, structure_result, layout_times); + + this->time_info_layout[0] += layout_times[0]; + this->time_info_layout[1] += layout_times[1]; + this->time_info_layout[2] += layout_times[2]; +} + void PaddleStructure::table(cv::Mat img, - StructurePredictResult &structure_result, - std::vector &time_info_table, - std::vector &time_info_det, - std::vector &time_info_rec, - std::vector &time_info_cls) { + StructurePredictResult &structure_result) { // predict structure std::vector> structure_html_tags; std::vector structure_scores(1, 0); - std::vector>>> structure_boxes; - std::vector structure_imes; + std::vector>> structure_boxes; + std::vector structure_times; std::vector img_list; img_list.push_back(img); - this->recognizer_->Run(img_list, structure_html_tags, structure_scores, - structure_boxes, structure_imes); - time_info_table[0] += structure_imes[0]; - time_info_table[1] += structure_imes[1]; - time_info_table[2] += structure_imes[2]; + + this->table_model_->Run(img_list, structure_html_tags, structure_scores, + structure_boxes, structure_times); + + this->time_info_table[0] += structure_times[0]; + this->time_info_table[1] += structure_times[1]; + this->time_info_table[2] += structure_times[2]; std::vector ocr_result; std::string html; @@ -100,63 +100,57 @@ void PaddleStructure::table(cv::Mat img, for (int i = 0; i < 
img_list.size(); i++) { // det - this->det(img_list[i], ocr_result, time_info_det); + this->det(img_list[i], ocr_result); // crop image std::vector rec_img_list; + std::vector ocr_box; for (int j = 0; j < ocr_result.size(); j++) { - int x_collect[4] = {ocr_result[j].box[0][0], ocr_result[j].box[1][0], - ocr_result[j].box[2][0], ocr_result[j].box[3][0]}; - int y_collect[4] = {ocr_result[j].box[0][1], ocr_result[j].box[1][1], - ocr_result[j].box[2][1], ocr_result[j].box[3][1]}; - int left = int(*std::min_element(x_collect, x_collect + 4)); - int right = int(*std::max_element(x_collect, x_collect + 4)); - int top = int(*std::min_element(y_collect, y_collect + 4)); - int bottom = int(*std::max_element(y_collect, y_collect + 4)); - std::vector box{max(0, left - expand_pixel), - max(0, top - expand_pixel), - min(img_list[i].cols, right + expand_pixel), - min(img_list[i].rows, bottom + expand_pixel)}; - cv::Mat crop_img = Utility::crop_image(img_list[i], box); + ocr_box = Utility::xyxyxyxy2xyxy(ocr_result[j].box); + ocr_box[0] = std::max(0, ocr_box[0] - expand_pixel); + ocr_box[1] = std::max(0, ocr_box[1] - expand_pixel), + ocr_box[2] = std::min(img_list[i].cols, ocr_box[2] + expand_pixel); + ocr_box[3] = std::min(img_list[i].rows, ocr_box[3] + expand_pixel); + + cv::Mat crop_img = Utility::crop_image(img_list[i], ocr_box); rec_img_list.push_back(crop_img); } // rec - this->rec(rec_img_list, ocr_result, time_info_rec); + this->rec(rec_img_list, ocr_result); // rebuild table html = this->rebuild_table(structure_html_tags[i], structure_boxes[i], ocr_result); structure_result.html = html; + structure_result.cell_box = structure_boxes[i]; structure_result.html_score = structure_scores[i]; } }; -std::string PaddleStructure::rebuild_table( - std::vector structure_html_tags, - std::vector>> structure_boxes, - std::vector &ocr_result) { +std::string +PaddleStructure::rebuild_table(std::vector structure_html_tags, + std::vector> structure_boxes, + std::vector &ocr_result) { // match text in same cell - std::vector> matched(structure_boxes.size(), - std::vector()); + std::vector> matched(structure_boxes.size(), + std::vector()); + std::vector ocr_box; + std::vector structure_box; for (int i = 0; i < ocr_result.size(); i++) { + ocr_box = Utility::xyxyxyxy2xyxy(ocr_result[i].box); + ocr_box[0] -= 1; + ocr_box[1] -= 1; + ocr_box[2] += 1; + ocr_box[3] += 1; std::vector> dis_list(structure_boxes.size(), std::vector(3, 100000.0)); for (int j = 0; j < structure_boxes.size(); j++) { - int x_collect[4] = {ocr_result[i].box[0][0], ocr_result[i].box[1][0], - ocr_result[i].box[2][0], ocr_result[i].box[3][0]}; - int y_collect[4] = {ocr_result[i].box[0][1], ocr_result[i].box[1][1], - ocr_result[i].box[2][1], ocr_result[i].box[3][1]}; - int left = int(*std::min_element(x_collect, x_collect + 4)); - int right = int(*std::max_element(x_collect, x_collect + 4)); - int top = int(*std::min_element(y_collect, y_collect + 4)); - int bottom = int(*std::max_element(y_collect, y_collect + 4)); - std::vector> box(2, std::vector(2, 0)); - box[0][0] = left - 1; - box[0][1] = top - 1; - box[1][0] = right + 1; - box[1][1] = bottom + 1; - - dis_list[j][0] = this->dis(box, structure_boxes[j]); - dis_list[j][1] = 1 - this->iou(box, structure_boxes[j]); + if (structure_boxes[i].size() == 8) { + structure_box = Utility::xyxyxyxy2xyxy(structure_boxes[j]); + } else { + structure_box = structure_boxes[j]; + } + dis_list[j][0] = this->dis(ocr_box, structure_box); + dis_list[j][1] = 1 - Utility::iou(ocr_box, structure_box); dis_list[j][2] = j; } 
// find min dis idx @@ -164,6 +158,7 @@ std::string PaddleStructure::rebuild_table( PaddleStructure::comparison_dis); matched[dis_list[0][2]].push_back(ocr_result[i].text); } + // get pred html std::string html_str = ""; int td_tag_idx = 0; @@ -221,51 +216,79 @@ std::string PaddleStructure::rebuild_table( return html_str; } -float PaddleStructure::iou(std::vector> &box1, - std::vector> &box2) { - int area1 = max(0, box1[1][0] - box1[0][0]) * max(0, box1[1][1] - box1[0][1]); - int area2 = max(0, box2[1][0] - box2[0][0]) * max(0, box2[1][1] - box2[0][1]); - - // computing the sum_area - int sum_area = area1 + area2; +float PaddleStructure::dis(std::vector &box1, std::vector &box2) { + int x1_1 = box1[0]; + int y1_1 = box1[1]; + int x2_1 = box1[2]; + int y2_1 = box1[3]; - // find the each point of intersect rectangle - int x1 = max(box1[0][0], box2[0][0]); - int y1 = max(box1[0][1], box2[0][1]); - int x2 = min(box1[1][0], box2[1][0]); - int y2 = min(box1[1][1], box2[1][1]); - - // judge if there is an intersect - if (y1 >= y2 || x1 >= x2) { - return 0.0; - } else { - int intersect = (x2 - x1) * (y2 - y1); - return intersect / (sum_area - intersect + 0.00000001); - } -} - -float PaddleStructure::dis(std::vector> &box1, - std::vector> &box2) { - int x1_1 = box1[0][0]; - int y1_1 = box1[0][1]; - int x2_1 = box1[1][0]; - int y2_1 = box1[1][1]; - - int x1_2 = box2[0][0]; - int y1_2 = box2[0][1]; - int x2_2 = box2[1][0]; - int y2_2 = box2[1][1]; + int x1_2 = box2[0]; + int y1_2 = box2[1]; + int x2_2 = box2[2]; + int y2_2 = box2[3]; float dis = abs(x1_2 - x1_1) + abs(y1_2 - y1_1) + abs(x2_2 - x2_1) + abs(y2_2 - y2_1); float dis_2 = abs(x1_2 - x1_1) + abs(y1_2 - y1_1); float dis_3 = abs(x2_2 - x2_1) + abs(y2_2 - y2_1); - return dis + min(dis_2, dis_3); + return dis + std::min(dis_2, dis_3); +} + +void PaddleStructure::reset_timer() { + this->time_info_det = {0, 0, 0}; + this->time_info_rec = {0, 0, 0}; + this->time_info_cls = {0, 0, 0}; + this->time_info_table = {0, 0, 0}; + this->time_info_layout = {0, 0, 0}; +} + +void PaddleStructure::benchmark_log(int img_num) { + if (this->time_info_det[0] + this->time_info_det[1] + this->time_info_det[2] > + 0) { + AutoLogger autolog_det("ocr_det", FLAGS_use_gpu, FLAGS_use_tensorrt, + FLAGS_enable_mkldnn, FLAGS_cpu_threads, 1, "dynamic", + FLAGS_precision, this->time_info_det, img_num); + autolog_det.report(); + } + if (this->time_info_rec[0] + this->time_info_rec[1] + this->time_info_rec[2] > + 0) { + AutoLogger autolog_rec("ocr_rec", FLAGS_use_gpu, FLAGS_use_tensorrt, + FLAGS_enable_mkldnn, FLAGS_cpu_threads, + FLAGS_rec_batch_num, "dynamic", FLAGS_precision, + this->time_info_rec, img_num); + autolog_rec.report(); + } + if (this->time_info_cls[0] + this->time_info_cls[1] + this->time_info_cls[2] > + 0) { + AutoLogger autolog_cls("ocr_cls", FLAGS_use_gpu, FLAGS_use_tensorrt, + FLAGS_enable_mkldnn, FLAGS_cpu_threads, + FLAGS_cls_batch_num, "dynamic", FLAGS_precision, + this->time_info_cls, img_num); + autolog_cls.report(); + } + if (this->time_info_table[0] + this->time_info_table[1] + + this->time_info_table[2] > + 0) { + AutoLogger autolog_table("table", FLAGS_use_gpu, FLAGS_use_tensorrt, + FLAGS_enable_mkldnn, FLAGS_cpu_threads, + FLAGS_cls_batch_num, "dynamic", FLAGS_precision, + this->time_info_table, img_num); + autolog_table.report(); + } + if (this->time_info_layout[0] + this->time_info_layout[1] + + this->time_info_layout[2] > + 0) { + AutoLogger autolog_layout("layout", FLAGS_use_gpu, FLAGS_use_tensorrt, + FLAGS_enable_mkldnn, FLAGS_cpu_threads, + 
FLAGS_cls_batch_num, "dynamic", FLAGS_precision, + this->time_info_layout, img_num); + autolog_layout.report(); + } } PaddleStructure::~PaddleStructure() { - if (this->recognizer_ != nullptr) { - delete this->recognizer_; + if (this->table_model_ != nullptr) { + delete this->table_model_; } }; diff --git a/deploy/cpp_infer/src/postprocess_op.cpp b/deploy/cpp_infer/src/postprocess_op.cpp index 551f98a1668124f83ef615f0a41b081508898d6e..c139fa7236856fa653b21bc7df5914290df0e21c 100644 --- a/deploy/cpp_infer/src/postprocess_op.cpp +++ b/deploy/cpp_infer/src/postprocess_op.cpp @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include #include namespace PaddleOCR { @@ -352,8 +351,21 @@ std::vector>> DBPostProcessor::FilterTagDetRes( return root_points; } -void TablePostProcessor::init(std::string label_path) { +void TablePostProcessor::init(std::string label_path, + bool merge_no_span_structure) { this->label_list_ = Utility::ReadDict(label_path); + if (merge_no_span_structure) { + this->label_list_.push_back(""); + std::vector::iterator it; + for (it = this->label_list_.begin(); it != this->label_list_.end();) { + if (*it == "") { + it = this->label_list_.erase(it); + } else { + ++it; + } + } + } + // add_special_char this->label_list_.insert(this->label_list_.begin(), this->beg); this->label_list_.push_back(this->end); } @@ -363,12 +375,12 @@ void TablePostProcessor::Run( std::vector &rec_scores, std::vector &loc_preds_shape, std::vector &structure_probs_shape, std::vector> &rec_html_tag_batch, - std::vector>>> &rec_boxes_batch, + std::vector>> &rec_boxes_batch, std::vector &width_list, std::vector &height_list) { for (int batch_idx = 0; batch_idx < structure_probs_shape[0]; batch_idx++) { // image tags and boxs std::vector rec_html_tags; - std::vector>> rec_boxes; + std::vector> rec_boxes; float score = 0.f; int count = 0; @@ -378,7 +390,7 @@ void TablePostProcessor::Run( // step for (int step_idx = 0; step_idx < structure_probs_shape[1]; step_idx++) { std::string html_tag; - std::vector> rec_box; + std::vector rec_box; // html tag int step_start_idx = (batch_idx * structure_probs_shape[1] + step_idx) * structure_probs_shape[2]; @@ -399,24 +411,26 @@ void TablePostProcessor::Run( count += 1; score += char_score; rec_html_tags.push_back(html_tag); + // box if (html_tag == "" || html_tag == "") { - for (int point_idx = 0; point_idx < loc_preds_shape[2]; - point_idx += 2) { - std::vector point(2, 0); + for (int point_idx = 0; point_idx < loc_preds_shape[2]; point_idx++) { step_start_idx = (batch_idx * structure_probs_shape[1] + step_idx) * loc_preds_shape[2] + point_idx; - point[0] = int(loc_preds[step_start_idx] * width_list[batch_idx]); - point[1] = - int(loc_preds[step_start_idx + 1] * height_list[batch_idx]); + float point = loc_preds[step_start_idx]; + if (point_idx % 2 == 0) { + point = int(point * width_list[batch_idx]); + } else { + point = int(point * height_list[batch_idx]); + } rec_box.push_back(point); } rec_boxes.push_back(rec_box); } } score /= count; - if (isnan(score) || rec_boxes.size() == 0) { + if (std::isnan(score) || rec_boxes.size() == 0) { score = -1; } rec_scores.push_back(score); @@ -425,4 +439,137 @@ void TablePostProcessor::Run( } } +void PicodetPostProcessor::init(std::string label_path, + const double score_threshold, + const double nms_threshold, + const std::vector &fpn_stride) { + this->label_list_ = Utility::ReadDict(label_path); + this->score_threshold_ = score_threshold; + 
this->nms_threshold_ = nms_threshold; + this->num_class_ = label_list_.size(); + this->fpn_stride_ = fpn_stride; +} + +void PicodetPostProcessor::Run(std::vector &results, + std::vector> outs, + std::vector ori_shape, + std::vector resize_shape, int reg_max) { + int in_h = resize_shape[0]; + int in_w = resize_shape[1]; + float scale_factor_h = resize_shape[0] / float(ori_shape[0]); + float scale_factor_w = resize_shape[1] / float(ori_shape[1]); + + std::vector> bbox_results; + bbox_results.resize(this->num_class_); + for (int i = 0; i < this->fpn_stride_.size(); ++i) { + int feature_h = std::ceil((float)in_h / this->fpn_stride_[i]); + int feature_w = std::ceil((float)in_w / this->fpn_stride_[i]); + for (int idx = 0; idx < feature_h * feature_w; idx++) { + // score and label + float score = 0; + int cur_label = 0; + for (int label = 0; label < this->num_class_; label++) { + if (outs[i][idx * this->num_class_ + label] > score) { + score = outs[i][idx * this->num_class_ + label]; + cur_label = label; + } + } + // bbox + if (score > this->score_threshold_) { + int row = idx / feature_w; + int col = idx % feature_w; + std::vector bbox_pred( + outs[i + this->fpn_stride_.size()].begin() + idx * 4 * reg_max, + outs[i + this->fpn_stride_.size()].begin() + + (idx + 1) * 4 * reg_max); + bbox_results[cur_label].push_back( + this->disPred2Bbox(bbox_pred, cur_label, score, col, row, + this->fpn_stride_[i], resize_shape, reg_max)); + } + } + } + for (int i = 0; i < bbox_results.size(); i++) { + bool flag = bbox_results[i].size() <= 0; + } + for (int i = 0; i < bbox_results.size(); i++) { + bool flag = bbox_results[i].size() <= 0; + if (bbox_results[i].size() <= 0) { + continue; + } + this->nms(bbox_results[i], this->nms_threshold_); + for (auto box : bbox_results[i]) { + box.box[0] = box.box[0] / scale_factor_w; + box.box[2] = box.box[2] / scale_factor_w; + box.box[1] = box.box[1] / scale_factor_h; + box.box[3] = box.box[3] / scale_factor_h; + results.push_back(box); + } + } +} + +StructurePredictResult +PicodetPostProcessor::disPred2Bbox(std::vector bbox_pred, int label, + float score, int x, int y, int stride, + std::vector im_shape, int reg_max) { + float ct_x = (x + 0.5) * stride; + float ct_y = (y + 0.5) * stride; + std::vector dis_pred; + dis_pred.resize(4); + for (int i = 0; i < 4; i++) { + float dis = 0; + std::vector bbox_pred_i(bbox_pred.begin() + i * reg_max, + bbox_pred.begin() + (i + 1) * reg_max); + std::vector dis_after_sm = + Utility::activation_function_softmax(bbox_pred_i); + for (int j = 0; j < reg_max; j++) { + dis += j * dis_after_sm[j]; + } + dis *= stride; + dis_pred[i] = dis; + } + + float xmin = (std::max)(ct_x - dis_pred[0], .0f); + float ymin = (std::max)(ct_y - dis_pred[1], .0f); + float xmax = (std::min)(ct_x + dis_pred[2], (float)im_shape[1]); + float ymax = (std::min)(ct_y + dis_pred[3], (float)im_shape[0]); + + StructurePredictResult result_item; + result_item.box = {xmin, ymin, xmax, ymax}; + result_item.type = this->label_list_[label]; + result_item.confidence = score; + + return result_item; +} + +void PicodetPostProcessor::nms(std::vector &input_boxes, + float nms_threshold) { + std::sort(input_boxes.begin(), input_boxes.end(), + [](StructurePredictResult a, StructurePredictResult b) { + return a.confidence > b.confidence; + }); + std::vector picked(input_boxes.size(), 1); + + for (int i = 0; i < input_boxes.size(); ++i) { + if (picked[i] == 0) { + continue; + } + for (int j = i + 1; j < input_boxes.size(); ++j) { + if (picked[j] == 0) { + continue; + } + float iou = 
Utility::iou(input_boxes[i].box, input_boxes[j].box); + if (iou > nms_threshold) { + picked[j] = 0; + } + } + } + std::vector input_boxes_nms; + for (int i = 0; i < input_boxes.size(); ++i) { + if (picked[i] == 1) { + input_boxes_nms.push_back(input_boxes[i]); + } + } + input_boxes = input_boxes_nms; +} + } // namespace PaddleOCR diff --git a/deploy/cpp_infer/src/preprocess_op.cpp b/deploy/cpp_infer/src/preprocess_op.cpp index ac185e22d68955ef440e22c327b835dbce6c4e1b..19cd6c3f799e66c50a004881272e0c4a1e357c1d 100644 --- a/deploy/cpp_infer/src/preprocess_op.cpp +++ b/deploy/cpp_infer/src/preprocess_op.cpp @@ -12,21 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "opencv2/core.hpp" -#include "opencv2/imgcodecs.hpp" -#include "opencv2/imgproc.hpp" -#include "paddle_api.h" -#include "paddle_inference_api.h" -#include -#include -#include -#include -#include - -#include -#include -#include - #include namespace PaddleOCR { @@ -69,13 +54,13 @@ void Normalize::Run(cv::Mat *im, const std::vector &mean, } void ResizeImgType0::Run(const cv::Mat &img, cv::Mat &resize_img, - string limit_type, int limit_side_len, float &ratio_h, - float &ratio_w, bool use_tensorrt) { + std::string limit_type, int limit_side_len, + float &ratio_h, float &ratio_w, bool use_tensorrt) { int w = img.cols; int h = img.rows; float ratio = 1.f; if (limit_type == "min") { - int min_wh = min(h, w); + int min_wh = std::min(h, w); if (min_wh < limit_side_len) { if (h < w) { ratio = float(limit_side_len) / float(h); @@ -84,7 +69,7 @@ void ResizeImgType0::Run(const cv::Mat &img, cv::Mat &resize_img, } } } else { - int max_wh = max(h, w); + int max_wh = std::max(h, w); if (max_wh > limit_side_len) { if (h > w) { ratio = float(limit_side_len) / float(h); @@ -97,8 +82,8 @@ void ResizeImgType0::Run(const cv::Mat &img, cv::Mat &resize_img, int resize_h = int(float(h) * ratio); int resize_w = int(float(w) * ratio); - resize_h = max(int(round(float(resize_h) / 32) * 32), 32); - resize_w = max(int(round(float(resize_w) / 32) * 32), 32); + resize_h = std::max(int(round(float(resize_h) / 32) * 32), 32); + resize_w = std::max(int(round(float(resize_w) / 32) * 32), 32); cv::resize(img, resize_img, cv::Size(resize_w, resize_h)); ratio_h = float(resize_h) / float(h); @@ -175,4 +160,9 @@ void TablePadImg::Run(const cv::Mat &img, cv::Mat &resize_img, cv::BORDER_CONSTANT, cv::Scalar(0, 0, 0)); } +void Resize::Run(const cv::Mat &img, cv::Mat &resize_img, const int h, + const int w) { + cv::resize(img, resize_img, cv::Size(w, h)); +} + } // namespace PaddleOCR diff --git a/deploy/cpp_infer/src/structure_layout.cpp b/deploy/cpp_infer/src/structure_layout.cpp new file mode 100644 index 0000000000000000000000000000000000000000..922959ae0238f01a0e9ce1bec41daba0a2c71669 --- /dev/null +++ b/deploy/cpp_infer/src/structure_layout.cpp @@ -0,0 +1,149 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
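+
+// structure_layout.cpp implements PP-Structure layout analysis: the input
+// image is resized to a fixed 800x608 (height x width), normalized and
+// permuted to CHW, run through the Paddle inference predictor, and the
+// multi-scale PicoDet outputs are decoded and NMS-filtered by
+// PicodetPostProcessor into StructurePredictResult layout boxes.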
+ +#include + +namespace PaddleOCR { + +void StructureLayoutRecognizer::Run(cv::Mat img, + std::vector &result, + std::vector ×) { + std::chrono::duration preprocess_diff = + std::chrono::steady_clock::now() - std::chrono::steady_clock::now(); + std::chrono::duration inference_diff = + std::chrono::steady_clock::now() - std::chrono::steady_clock::now(); + std::chrono::duration postprocess_diff = + std::chrono::steady_clock::now() - std::chrono::steady_clock::now(); + + // preprocess + auto preprocess_start = std::chrono::steady_clock::now(); + + cv::Mat srcimg; + img.copyTo(srcimg); + cv::Mat resize_img; + this->resize_op_.Run(srcimg, resize_img, 800, 608); + this->normalize_op_.Run(&resize_img, this->mean_, this->scale_, + this->is_scale_); + + std::vector input(1 * 3 * resize_img.rows * resize_img.cols, 0.0f); + this->permute_op_.Run(&resize_img, input.data()); + auto preprocess_end = std::chrono::steady_clock::now(); + preprocess_diff += preprocess_end - preprocess_start; + + // inference. + auto input_names = this->predictor_->GetInputNames(); + auto input_t = this->predictor_->GetInputHandle(input_names[0]); + input_t->Reshape({1, 3, resize_img.rows, resize_img.cols}); + auto inference_start = std::chrono::steady_clock::now(); + input_t->CopyFromCpu(input.data()); + + this->predictor_->Run(); + + // Get output tensor + std::vector> out_tensor_list; + std::vector> output_shape_list; + auto output_names = this->predictor_->GetOutputNames(); + for (int j = 0; j < output_names.size(); j++) { + auto output_tensor = this->predictor_->GetOutputHandle(output_names[j]); + std::vector output_shape = output_tensor->shape(); + int out_num = std::accumulate(output_shape.begin(), output_shape.end(), 1, + std::multiplies()); + output_shape_list.push_back(output_shape); + + std::vector out_data; + out_data.resize(out_num); + output_tensor->CopyToCpu(out_data.data()); + out_tensor_list.push_back(out_data); + } + auto inference_end = std::chrono::steady_clock::now(); + inference_diff += inference_end - inference_start; + + // postprocess + auto postprocess_start = std::chrono::steady_clock::now(); + + std::vector bbox_num; + int reg_max = 0; + for (int i = 0; i < out_tensor_list.size(); i++) { + if (i == this->post_processor_.fpn_stride_.size()) { + reg_max = output_shape_list[i][2] / 4; + break; + } + } + std::vector ori_shape = {srcimg.rows, srcimg.cols}; + std::vector resize_shape = {resize_img.rows, resize_img.cols}; + this->post_processor_.Run(result, out_tensor_list, ori_shape, resize_shape, + reg_max); + bbox_num.push_back(result.size()); + + auto postprocess_end = std::chrono::steady_clock::now(); + postprocess_diff += postprocess_end - postprocess_start; + times.push_back(double(preprocess_diff.count() * 1000)); + times.push_back(double(inference_diff.count() * 1000)); + times.push_back(double(postprocess_diff.count() * 1000)); +} + +void StructureLayoutRecognizer::LoadModel(const std::string &model_dir) { + paddle_infer::Config config; + if (Utility::PathExists(model_dir + "/inference.pdmodel") && + Utility::PathExists(model_dir + "/inference.pdiparams")) { + config.SetModel(model_dir + "/inference.pdmodel", + model_dir + "/inference.pdiparams"); + } else if (Utility::PathExists(model_dir + "/model.pdmodel") && + Utility::PathExists(model_dir + "/model.pdiparams")) { + config.SetModel(model_dir + "/model.pdmodel", + model_dir + "/model.pdiparams"); + } else { + std::cerr << "[ERROR] not find model.pdiparams or inference.pdiparams in " + << model_dir << std::endl; + exit(1); + } + + if 
(this->use_gpu_) { + config.EnableUseGpu(this->gpu_mem_, this->gpu_id_); + if (this->use_tensorrt_) { + auto precision = paddle_infer::Config::Precision::kFloat32; + if (this->precision_ == "fp16") { + precision = paddle_infer::Config::Precision::kHalf; + } + if (this->precision_ == "int8") { + precision = paddle_infer::Config::Precision::kInt8; + } + config.EnableTensorRtEngine(1 << 20, 10, 3, precision, false, false); + if (!Utility::PathExists("./trt_layout_shape.txt")) { + config.CollectShapeRangeInfo("./trt_layout_shape.txt"); + } else { + config.EnableTunedTensorRtDynamicShape("./trt_layout_shape.txt", true); + } + } + } else { + config.DisableGpu(); + if (this->use_mkldnn_) { + config.EnableMKLDNN(); + } + config.SetCpuMathLibraryNumThreads(this->cpu_math_library_num_threads_); + } + + // false for zero copy tensor + config.SwitchUseFeedFetchOps(false); + // true for multiple input + config.SwitchSpecifyInputNames(true); + + config.SwitchIrOptim(true); + + config.EnableMemoryOptim(); + config.DisableGlogInfo(); + + this->predictor_ = paddle_infer::CreatePredictor(config); +} +} // namespace PaddleOCR diff --git a/deploy/cpp_infer/src/structure_table.cpp b/deploy/cpp_infer/src/structure_table.cpp index bbc32580e49d6ed7b29e3f0931eab0b0969b02b9..52f5d9ee9e46d88fd6e34bbb3afe86cbf7858140 100644 --- a/deploy/cpp_infer/src/structure_table.cpp +++ b/deploy/cpp_infer/src/structure_table.cpp @@ -20,7 +20,7 @@ void StructureTableRecognizer::Run( std::vector img_list, std::vector> &structure_html_tags, std::vector &structure_scores, - std::vector>>> &structure_boxes, + std::vector>> &structure_boxes, std::vector ×) { std::chrono::duration preprocess_diff = std::chrono::steady_clock::now() - std::chrono::steady_clock::now(); @@ -34,7 +34,7 @@ void StructureTableRecognizer::Run( beg_img_no += this->table_batch_num_) { // preprocess auto preprocess_start = std::chrono::steady_clock::now(); - int end_img_no = min(img_num, beg_img_no + this->table_batch_num_); + int end_img_no = std::min(img_num, beg_img_no + this->table_batch_num_); int batch_num = end_img_no - beg_img_no; std::vector norm_img_batch; std::vector width_list; @@ -89,8 +89,7 @@ void StructureTableRecognizer::Run( auto postprocess_start = std::chrono::steady_clock::now(); std::vector> structure_html_tag_batch; std::vector structure_score_batch; - std::vector>>> - structure_boxes_batch; + std::vector>> structure_boxes_batch; this->post_processor_.Run(loc_preds, structure_probs, structure_score_batch, predict_shape0, predict_shape1, structure_html_tag_batch, structure_boxes_batch, @@ -119,7 +118,7 @@ void StructureTableRecognizer::Run( } void StructureTableRecognizer::LoadModel(const std::string &model_dir) { - AnalysisConfig config; + paddle_infer::Config config; config.SetModel(model_dir + "/inference.pdmodel", model_dir + "/inference.pdiparams"); @@ -134,6 +133,11 @@ void StructureTableRecognizer::LoadModel(const std::string &model_dir) { precision = paddle_infer::Config::Precision::kInt8; } config.EnableTensorRtEngine(1 << 20, 10, 3, precision, false, false); + if (!Utility::PathExists("./trt_table_shape.txt")) { + config.CollectShapeRangeInfo("./trt_table_shape.txt"); + } else { + config.EnableTunedTensorRtDynamicShape("./trt_table_shape.txt", true); + } } } else { config.DisableGpu(); @@ -153,6 +157,6 @@ void StructureTableRecognizer::LoadModel(const std::string &model_dir) { config.EnableMemoryOptim(); config.DisableGlogInfo(); - this->predictor_ = CreatePredictor(config); + this->predictor_ = 
paddle_infer::CreatePredictor(config); } } // namespace PaddleOCR diff --git a/deploy/cpp_infer/src/utility.cpp b/deploy/cpp_infer/src/utility.cpp index 4bfc1d091d6124b10c79032beb702ba8727210fc..4a8b181494fca768b153e0825e8be0853f7f3aef 100644 --- a/deploy/cpp_infer/src/utility.cpp +++ b/deploy/cpp_infer/src/utility.cpp @@ -65,6 +65,38 @@ void Utility::VisualizeBboxes(const cv::Mat &srcimg, << std::endl; } +void Utility::VisualizeBboxes(const cv::Mat &srcimg, + const StructurePredictResult &structure_result, + const std::string &save_path) { + cv::Mat img_vis; + srcimg.copyTo(img_vis); + img_vis = crop_image(img_vis, structure_result.box); + for (int n = 0; n < structure_result.cell_box.size(); n++) { + if (structure_result.cell_box[n].size() == 8) { + cv::Point rook_points[4]; + for (int m = 0; m < structure_result.cell_box[n].size(); m += 2) { + rook_points[m / 2] = + cv::Point(int(structure_result.cell_box[n][m]), + int(structure_result.cell_box[n][m + 1])); + } + const cv::Point *ppt[1] = {rook_points}; + int npt[] = {4}; + cv::polylines(img_vis, ppt, npt, 1, 1, CV_RGB(0, 255, 0), 2, 8, 0); + } else if (structure_result.cell_box[n].size() == 4) { + cv::Point rook_points[2]; + rook_points[0] = cv::Point(int(structure_result.cell_box[n][0]), + int(structure_result.cell_box[n][1])); + rook_points[1] = cv::Point(int(structure_result.cell_box[n][2]), + int(structure_result.cell_box[n][3])); + cv::rectangle(img_vis, rook_points[0], rook_points[1], CV_RGB(0, 255, 0), + 2, 8, 0); + } + } + + cv::imwrite(save_path, img_vis); + std::cout << "The table visualized image saved in " + save_path << std::endl; +} + // list all files under a directory void Utility::GetAllFiles(const char *dir_name, std::vector &all_inputs) { @@ -249,32 +281,145 @@ void Utility::print_result(const std::vector &ocr_result) { } } -cv::Mat Utility::crop_image(cv::Mat &img, std::vector &area) { +cv::Mat Utility::crop_image(cv::Mat &img, const std::vector &box) { cv::Mat crop_im; - int crop_x1 = std::max(0, area[0]); - int crop_y1 = std::max(0, area[1]); - int crop_x2 = std::min(img.cols - 1, area[2] - 1); - int crop_y2 = std::min(img.rows - 1, area[3] - 1); + int crop_x1 = std::max(0, box[0]); + int crop_y1 = std::max(0, box[1]); + int crop_x2 = std::min(img.cols - 1, box[2] - 1); + int crop_y2 = std::min(img.rows - 1, box[3] - 1); - crop_im = cv::Mat::zeros(area[3] - area[1], area[2] - area[0], 16); + crop_im = cv::Mat::zeros(box[3] - box[1], box[2] - box[0], 16); cv::Mat crop_im_window = - crop_im(cv::Range(crop_y1 - area[1], crop_y2 + 1 - area[1]), - cv::Range(crop_x1 - area[0], crop_x2 + 1 - area[0])); + crop_im(cv::Range(crop_y1 - box[1], crop_y2 + 1 - box[1]), + cv::Range(crop_x1 - box[0], crop_x2 + 1 - box[0])); cv::Mat roi_img = img(cv::Range(crop_y1, crop_y2 + 1), cv::Range(crop_x1, crop_x2 + 1)); crop_im_window += roi_img; return crop_im; } +cv::Mat Utility::crop_image(cv::Mat &img, const std::vector &box) { + std::vector box_int = {(int)box[0], (int)box[1], (int)box[2], + (int)box[3]}; + return crop_image(img, box_int); +} + void Utility::sorted_boxes(std::vector &ocr_result) { std::sort(ocr_result.begin(), ocr_result.end(), Utility::comparison_box); - - for (int i = 0; i < ocr_result.size() - 1; i++) { - if (abs(ocr_result[i + 1].box[0][1] - ocr_result[i].box[0][1]) < 10 && - (ocr_result[i + 1].box[0][0] < ocr_result[i].box[0][0])) { - std::swap(ocr_result[i], ocr_result[i + 1]); + if (ocr_result.size() > 0) { + for (int i = 0; i < ocr_result.size() - 1; i++) { + for (int j = i; j > 0; j--) { + if 
(abs(ocr_result[j + 1].box[0][1] - ocr_result[j].box[0][1]) < 10 && + (ocr_result[j + 1].box[0][0] < ocr_result[j].box[0][0])) { + std::swap(ocr_result[i], ocr_result[i + 1]); + } + } } } } +std::vector Utility::xyxyxyxy2xyxy(std::vector> &box) { + int x_collect[4] = {box[0][0], box[1][0], box[2][0], box[3][0]}; + int y_collect[4] = {box[0][1], box[1][1], box[2][1], box[3][1]}; + int left = int(*std::min_element(x_collect, x_collect + 4)); + int right = int(*std::max_element(x_collect, x_collect + 4)); + int top = int(*std::min_element(y_collect, y_collect + 4)); + int bottom = int(*std::max_element(y_collect, y_collect + 4)); + std::vector box1(4, 0); + box1[0] = left; + box1[1] = top; + box1[2] = right; + box1[3] = bottom; + return box1; +} + +std::vector Utility::xyxyxyxy2xyxy(std::vector &box) { + int x_collect[4] = {box[0], box[2], box[4], box[6]}; + int y_collect[4] = {box[1], box[3], box[5], box[7]}; + int left = int(*std::min_element(x_collect, x_collect + 4)); + int right = int(*std::max_element(x_collect, x_collect + 4)); + int top = int(*std::min_element(y_collect, y_collect + 4)); + int bottom = int(*std::max_element(y_collect, y_collect + 4)); + std::vector box1(4, 0); + box1[0] = left; + box1[1] = top; + box1[2] = right; + box1[3] = bottom; + return box1; +} + +float Utility::fast_exp(float x) { + union { + uint32_t i; + float f; + } v{}; + v.i = (1 << 23) * (1.4426950409 * x + 126.93490512f); + return v.f; +} + +std::vector +Utility::activation_function_softmax(std::vector &src) { + int length = src.size(); + std::vector dst; + dst.resize(length); + const float alpha = float(*std::max_element(&src[0], &src[0 + length])); + float denominator{0}; + + for (int i = 0; i < length; ++i) { + dst[i] = fast_exp(src[i] - alpha); + denominator += dst[i]; + } + + for (int i = 0; i < length; ++i) { + dst[i] /= denominator; + } + return dst; +} + +float Utility::iou(std::vector &box1, std::vector &box2) { + int area1 = std::max(0, box1[2] - box1[0]) * std::max(0, box1[3] - box1[1]); + int area2 = std::max(0, box2[2] - box2[0]) * std::max(0, box2[3] - box2[1]); + + // computing the sum_area + int sum_area = area1 + area2; + + // find the each point of intersect rectangle + int x1 = std::max(box1[0], box2[0]); + int y1 = std::max(box1[1], box2[1]); + int x2 = std::min(box1[2], box2[2]); + int y2 = std::min(box1[3], box2[3]); + + // judge if there is an intersect + if (y1 >= y2 || x1 >= x2) { + return 0.0; + } else { + int intersect = (x2 - x1) * (y2 - y1); + return intersect / (sum_area - intersect + 0.00000001); + } +} + +float Utility::iou(std::vector &box1, std::vector &box2) { + float area1 = std::max((float)0.0, box1[2] - box1[0]) * + std::max((float)0.0, box1[3] - box1[1]); + float area2 = std::max((float)0.0, box2[2] - box2[0]) * + std::max((float)0.0, box2[3] - box2[1]); + + // computing the sum_area + float sum_area = area1 + area2; + + // find the each point of intersect rectangle + float x1 = std::max(box1[0], box2[0]); + float y1 = std::max(box1[1], box2[1]); + float x2 = std::min(box1[2], box2[2]); + float y2 = std::min(box1[3], box2[3]); + + // judge if there is an intersect + if (y1 >= y2 || x1 >= x2) { + return 0.0; + } else { + float intersect = (x2 - x1) * (y2 - y1); + return intersect / (sum_area - intersect + 0.00000001); + } +} + } // namespace PaddleOCR \ No newline at end of file diff --git a/deploy/lite/config.txt b/deploy/lite/config.txt index dda0d2b0320544d3a82f59b0672c086c64d83d3d..404249323b6cb5de345438056a9a10abd64b38bc 100644 --- a/deploy/lite/config.txt +++ 
b/deploy/lite/config.txt @@ -5,4 +5,4 @@ det_db_unclip_ratio 1.6 det_db_use_dilate 0 det_use_polygon_score 1 use_direction_classify 1 -rec_image_height 32 \ No newline at end of file +rec_image_height 48 \ No newline at end of file diff --git a/deploy/lite/readme.md b/deploy/lite/readme.md index a1bef8120e52dd91db0fda4ac2a4d91cc2800818..fc91cbfa7d69f6a8c1086243e4df3f820bd78339 100644 --- a/deploy/lite/readme.md +++ b/deploy/lite/readme.md @@ -99,6 +99,8 @@ The following table also provides a series of models that can be deployed on mob |Version|Introduction|Model size|Detection model|Text Direction model|Recognition model|Paddle-Lite branch| |---|---|---|---|---|---|---| +|PP-OCRv3|extra-lightweight chinese OCR optimized model|16.2M|[download link](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_infer.nb)|[download link](https://paddleocr.bj.bcebos.com/PP-OCRv2/lite/ch_ppocr_mobile_v2.0_cls_infer_opt.nb)|[download link](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_infer.nb)|v2.10| +|PP-OCRv3(slim)|extra-lightweight chinese OCR optimized model|5.9M|[download link](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_slim_infer.nb)|[download link](https://paddleocr.bj.bcebos.com/PP-OCRv2/lite/ch_ppocr_mobile_v2.0_cls_slim_opt.nb)|[download link](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_slim_infer.nb)|v2.10| |PP-OCRv2|extra-lightweight chinese OCR optimized model|11M|[download link](https://paddleocr.bj.bcebos.com/PP-OCRv2/lite/ch_PP-OCRv2_det_infer_opt.nb)|[download link](https://paddleocr.bj.bcebos.com/PP-OCRv2/lite/ch_ppocr_mobile_v2.0_cls_infer_opt.nb)|[download link](https://paddleocr.bj.bcebos.com/PP-OCRv2/lite/ch_PP-OCRv2_rec_infer_opt.nb)|v2.10| |PP-OCRv2(slim)|extra-lightweight chinese OCR optimized model|4.6M|[download link](https://paddleocr.bj.bcebos.com/PP-OCRv2/lite/ch_PP-OCRv2_det_slim_opt.nb)|[download link](https://paddleocr.bj.bcebos.com/PP-OCRv2/lite/ch_ppocr_mobile_v2.0_cls_slim_opt.nb)|[download link](https://paddleocr.bj.bcebos.com/PP-OCRv2/lite/ch_PP-OCRv2_rec_slim_opt.nb)|v2.10| @@ -134,17 +136,16 @@ Introduction to paddle_lite_opt parameters: The following takes the ultra-lightweight Chinese model of PaddleOCR as an example to introduce the use of the compiled opt file to complete the conversion of the inference model to the Paddle-Lite optimized model ``` -# 【[Recommendation] Download the Chinese and English inference model of PP-OCRv2 -wget https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_slim_quant_infer.tar && tar xf ch_PP-OCRv2_det_slim_quant_infer.tar -wget https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_slim_quant_infer.tar && tar xf ch_PP-OCRv2_rec_slim_quant_infer.tar +# 【[Recommendation] Download the Chinese and English inference model of PP-OCRv3 +wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_slim_infer.tar && tar xf ch_PP-OCRv3_det_slim_infer.tar +wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_slim_infer.tar && tar xf ch_PP-OCRv2_rec_slim_quant_infer.tar wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/slim/ch_ppocr_mobile_v2.0_cls_slim_infer.tar && tar xf ch_ppocr_mobile_v2.0_cls_slim_infer.tar # Convert detection model -./opt --model_file=./ch_PP-OCRv2_det_slim_quant_infer/inference.pdmodel --param_file=./ch_PP-OCRv2_det_slim_quant_infer/inference.pdiparams --optimize_out=./ch_PP-OCRv2_det_slim_opt --valid_targets=arm --optimize_out_type=naive_buffer +paddle_lite_opt 
--model_file=./ch_PP-OCRv3_det_slim_infer/inference.pdmodel --param_file=./ch_PP-OCRv3_det_slim_infer/inference.pdiparams --optimize_out=./ch_PP-OCRv3_det_slim_opt --valid_targets=arm --optimize_out_type=naive_buffer # Convert recognition model -./opt --model_file=./ch_PP-OCRv2_rec_slim_quant_infer/inference.pdmodel --param_file=./ch_PP-OCRv2_rec_slim_quant_infer/inference.pdiparams --optimize_out=./ch_PP-OCRv2_rec_slim_opt --valid_targets=arm --optimize_out_type=naive_buffer +paddle_lite_opt --model_file=./ch_PP-OCRv3_rec_slim_infer/inference.pdmodel --param_file=./ch_PP-OCRv3_rec_slim_infer/inference.pdiparams --optimize_out=./ch_PP-OCRv3_rec_slim_opt --valid_targets=arm --optimize_out_type=naive_buffer # Convert angle classifier model -./opt --model_file=./ch_ppocr_mobile_v2.0_cls_slim_infer/inference.pdmodel --param_file=./ch_ppocr_mobile_v2.0_cls_slim_infer/inference.pdiparams --optimize_out=./ch_ppocr_mobile_v2.0_cls_slim_opt --valid_targets=arm --optimize_out_type=naive_buffer - +paddle_lite_opt --model_file=./ch_ppocr_mobile_v2.0_cls_slim_infer/inference.pdmodel --param_file=./ch_ppocr_mobile_v2.0_cls_slim_infer/inference.pdiparams --optimize_out=./ch_ppocr_mobile_v2.0_cls_slim_opt --valid_targets=arm --optimize_out_type=naive_buffer ``` After the conversion is successful, there will be more files ending with `.nb` in the inference model directory, which is the successfully converted model file. @@ -197,15 +198,15 @@ Some preparatory work is required first. cp ../../../cxx/lib/libpaddle_light_api_shared.so ./debug/ ``` -Prepare the test image, taking PaddleOCR/doc/imgs/11.jpg as an example, copy the image file to the demo/cxx/ocr/debug/ folder. Prepare the model files optimized by the lite opt tool, ch_det_mv3_db_opt.nb, ch_rec_mv3_crnn_opt.nb, and place them under the demo/cxx/ocr/debug/ folder. +Prepare the test image, taking PaddleOCR/doc/imgs/11.jpg as an example, copy the image file to the demo/cxx/ocr/debug/ folder. Prepare the model files optimized by the lite opt tool, ch_PP-OCRv3_det_slim_opt.nb , ch_PP-OCRv3_rec_slim_opt.nb , and place them under the demo/cxx/ocr/debug/ folder. The structure of the OCR demo is as follows after the above command is executed: ``` demo/cxx/ocr/ |-- debug/ -| |--ch_PP-OCRv2_det_slim_opt.nb Detection model -| |--ch_PP-OCRv2_rec_slim_opt.nb Recognition model +| |--ch_PP-OCRv3_det_slim_opt.nb Detection model +| |--ch_PP-OCRv3_rec_slim_opt.nb Recognition model | |--ch_ppocr_mobile_v2.0_cls_slim_opt.nb Text direction classification model | |--11.jpg Image for OCR | |--ppocr_keys_v1.txt Dictionary file @@ -240,7 +241,7 @@ det_db_thresh 0.3 # Used to filter the binarized image of DB prediction, det_db_box_thresh 0.5 # DDB post-processing filter box threshold, if there is a missing box detected, it can be reduced as appropriate det_db_unclip_ratio 1.6 # Indicates the compactness of the text box, the smaller the value, the closer the text box to the text use_direction_classify 0 # Whether to use the direction classifier, 0 means not to use, 1 means to use -rec_image_height 32 # The height of the input image of the recognition model, the PP-OCRv3 model needs to be set to 48, and the PP-OCRv2 model needs to be set to 32 +rec_image_height 48 # The height of the input image of the recognition model, the PP-OCRv3 model needs to be set to 48, and the PP-OCRv2 model needs to be set to 32 ``` 5. 
Run Model on phone @@ -260,14 +261,14 @@ After the above steps are completed, you can use adb to push the file to the pho export LD_LIBRARY_PATH=${PWD}:$LD_LIBRARY_PATH # The use of ocr_db_crnn is: # ./ocr_db_crnn Mode Detection model file Orientation classifier model file Recognition model file Hardware Precision Threads Batchsize Test image path Dictionary file path - ./ocr_db_crnn system ch_PP-OCRv2_det_slim_opt.nb ch_PP-OCRv2_rec_slim_opt.nb ch_ppocr_mobile_v2.0_cls_slim_opt.nb arm8 INT8 10 1 ./11.jpg config.txt ppocr_keys_v1.txt True + ./ocr_db_crnn system ch_PP-OCRv3_det_slim_opt.nb ch_PP-OCRv3_rec_slim_opt.nb ch_ppocr_mobile_v2.0_cls_slim_opt.nb arm8 INT8 10 1 ./11.jpg config.txt ppocr_keys_v1.txt True # precision can be INT8 for quantitative model or FP32 for normal model. # Only using detection model -./ocr_db_crnn det ch_PP-OCRv2_det_slim_opt.nb arm8 INT8 10 1 ./11.jpg config.txt +./ocr_db_crnn det ch_PP-OCRv3_det_slim_opt.nb arm8 INT8 10 1 ./11.jpg config.txt # Only using recognition model -./ocr_db_crnn rec ch_PP-OCRv2_rec_slim_opt.nb arm8 INT8 10 1 word_1.jpg ppocr_keys_v1.txt config.txt +./ocr_db_crnn rec ch_PP-OCRv3_rec_slim_opt.nb arm8 INT8 10 1 word_1.jpg ppocr_keys_v1.txt config.txt ``` If you modify the code, you need to recompile and push to the phone. diff --git a/deploy/lite/readme_ch.md b/deploy/lite/readme_ch.md index 0793827fe647c470944fc36e2b243c8f7e704e99..78e2510917e0fd85c4a724ec74eccb0b7cfc6118 100644 --- a/deploy/lite/readme_ch.md +++ b/deploy/lite/readme_ch.md @@ -97,6 +97,8 @@ Paddle-Lite 提供了多种策略来自动优化原始的模型,其中包括 |模型版本|模型简介|模型大小|检测模型|文本方向分类模型|识别模型|Paddle-Lite版本| |---|---|---|---|---|---|---| +|PP-OCRv3|蒸馏版超轻量中文OCR移动端模型|16.2M|[下载地址](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_infer.nb)|[下载地址](https://paddleocr.bj.bcebos.com/PP-OCRv2/lite/ch_ppocr_mobile_v2.0_cls_infer_opt.nb)|[下载地址](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_infer.nb)|v2.10| +|PP-OCRv3(slim)|蒸馏版超轻量中文OCR移动端模型|5.9M|[下载地址](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_slim_infer.nb)|[下载地址](https://paddleocr.bj.bcebos.com/PP-OCRv2/lite/ch_ppocr_mobile_v2.0_cls_slim_opt.nb)|[下载地址](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_slim_infer.nb)|v2.10| |PP-OCRv2|蒸馏版超轻量中文OCR移动端模型|11M|[下载地址](https://paddleocr.bj.bcebos.com/PP-OCRv2/lite/ch_PP-OCRv2_det_infer_opt.nb)|[下载地址](https://paddleocr.bj.bcebos.com/PP-OCRv2/lite/ch_ppocr_mobile_v2.0_cls_infer_opt.nb)|[下载地址](https://paddleocr.bj.bcebos.com/PP-OCRv2/lite/ch_PP-OCRv2_rec_infer_opt.nb)|v2.10| |PP-OCRv2(slim)|蒸馏版超轻量中文OCR移动端模型|4.6M|[下载地址](https://paddleocr.bj.bcebos.com/PP-OCRv2/lite/ch_PP-OCRv2_det_slim_opt.nb)|[下载地址](https://paddleocr.bj.bcebos.com/PP-OCRv2/lite/ch_ppocr_mobile_v2.0_cls_slim_opt.nb)|[下载地址](https://paddleocr.bj.bcebos.com/PP-OCRv2/lite/ch_PP-OCRv2_rec_slim_opt.nb)|v2.10| @@ -131,16 +133,16 @@ paddle_lite_opt 参数介绍: 下面以PaddleOCR的超轻量中文模型为例,介绍使用编译好的opt文件完成inference模型到Paddle-Lite优化模型的转换。 ``` -# 【推荐】 下载 PP-OCRv2版本的中英文 inference模型 -wget https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_slim_quant_infer.tar && tar xf ch_PP-OCRv2_det_slim_quant_infer.tar -wget https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_slim_quant_infer.tar && tar xf ch_PP-OCRv2_rec_slim_quant_infer.tar +# 【推荐】 下载 PP-OCRv3版本的中英文 inference模型 +wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_slim_infer.tar && tar xf ch_PP-OCRv3_det_slim_infer.tar +wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_slim_infer.tar && tar xf 
ch_PP-OCRv2_rec_slim_quant_infer.tar wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/slim/ch_ppocr_mobile_v2.0_cls_slim_infer.tar && tar xf ch_ppocr_mobile_v2.0_cls_slim_infer.tar # 转换检测模型 -./opt --model_file=./ch_PP-OCRv2_det_slim_quant_infer/inference.pdmodel --param_file=./ch_PP-OCRv2_det_slim_quant_infer/inference.pdiparams --optimize_out=./ch_PP-OCRv2_det_slim_opt --valid_targets=arm --optimize_out_type=naive_buffer +paddle_lite_opt --model_file=./ch_PP-OCRv3_det_slim_infer/inference.pdmodel --param_file=./ch_PP-OCRv3_det_slim_infer/inference.pdiparams --optimize_out=./ch_PP-OCRv3_det_slim_opt --valid_targets=arm --optimize_out_type=naive_buffer # 转换识别模型 -./opt --model_file=./ch_PP-OCRv2_rec_slim_quant_infer/inference.pdmodel --param_file=./ch_PP-OCRv2_rec_slim_quant_infer/inference.pdiparams --optimize_out=./ch_PP-OCRv2_rec_slim_opt --valid_targets=arm --optimize_out_type=naive_buffer +paddle_lite_opt --model_file=./ch_PP-OCRv3_rec_slim_infer/inference.pdmodel --param_file=./ch_PP-OCRv3_rec_slim_infer/inference.pdiparams --optimize_out=./ch_PP-OCRv3_rec_slim_opt --valid_targets=arm --optimize_out_type=naive_buffer # 转换方向分类器模型 -./opt --model_file=./ch_ppocr_mobile_v2.0_cls_slim_infer/inference.pdmodel --param_file=./ch_ppocr_mobile_v2.0_cls_slim_infer/inference.pdiparams --optimize_out=./ch_ppocr_mobile_v2.0_cls_slim_opt --valid_targets=arm --optimize_out_type=naive_buffer +paddle_lite_opt --model_file=./ch_ppocr_mobile_v2.0_cls_slim_infer/inference.pdmodel --param_file=./ch_ppocr_mobile_v2.0_cls_slim_infer/inference.pdiparams --optimize_out=./ch_ppocr_mobile_v2.0_cls_slim_opt --valid_targets=arm --optimize_out_type=naive_buffer ``` @@ -194,15 +196,15 @@ wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/slim/ch_ppocr_mobile_v2.0_cls ``` 准备测试图像,以`PaddleOCR/doc/imgs/11.jpg`为例,将测试的图像复制到`demo/cxx/ocr/debug/`文件夹下。 - 准备lite opt工具优化后的模型文件,比如使用`ch_PP-OCRv2_det_slim_opt.ch_PP-OCRv2_rec_slim_rec.nb, ch_ppocr_mobile_v2.0_cls_slim_opt.nb`,模型文件放置在`demo/cxx/ocr/debug/`文件夹下。 + 准备lite opt工具优化后的模型文件,比如使用`ch_PP-OCRv3_det_slim_opt.ch_PP-OCRv3_rec_slim_rec.nb, ch_ppocr_mobile_v2.0_cls_slim_opt.nb`,模型文件放置在`demo/cxx/ocr/debug/`文件夹下。 执行完成后,ocr文件夹下将有如下文件格式: ``` demo/cxx/ocr/ |-- debug/ -| |--ch_PP-OCRv2_det_slim_opt.nb 优化后的检测模型文件 -| |--ch_PP-OCRv2_rec_slim_opt.nb 优化后的识别模型文件 +| |--ch_PP-OCRv3_det_slim_opt.nb 优化后的检测模型文件 +| |--ch_PP-OCRv3_rec_slim_opt.nb 优化后的识别模型文件 | |--ch_ppocr_mobile_v2.0_cls_slim_opt.nb 优化后的文字方向分类器模型文件 | |--11.jpg 待测试图像 | |--ppocr_keys_v1.txt 中文字典文件 @@ -239,7 +241,7 @@ det_db_thresh 0.3 # 用于过滤DB预测的二值化图像,设置为0. det_db_box_thresh 0.5 # 检测器后处理过滤box的阈值,如果检测存在漏框情况,可酌情减小 det_db_unclip_ratio 1.6 # 表示文本框的紧致程度,越小则文本框更靠近文本 use_direction_classify 0 # 是否使用方向分类器,0表示不使用,1表示使用 -rec_image_height 32 # 识别模型输入图像的高度,PP-OCRv3模型设置为48,PP-OCRv2模型需要设置为32 +rec_image_height 48 # 识别模型输入图像的高度,PP-OCRv3模型设置为48,PP-OCRv2模型需要设置为32 ``` 5. 
启动调试 @@ -259,13 +261,13 @@ rec_image_height 32 # 识别模型输入图像的高度,PP-OCRv3模型 export LD_LIBRARY_PATH=${PWD}:$LD_LIBRARY_PATH # 开始使用,ocr_db_crnn可执行文件的使用方式为: # ./ocr_db_crnn 预测模式 检测模型文件 方向分类器模型文件 识别模型文件 运行硬件 运行精度 线程数 batchsize 测试图像路径 参数配置路径 字典文件路径 是否使用benchmark参数 - ./ocr_db_crnn system ch_PP-OCRv2_det_slim_opt.nb ch_PP-OCRv2_rec_slim_opt.nb ch_ppocr_mobile_v2.0_cls_slim_opt.nb arm8 INT8 10 1 ./11.jpg config.txt ppocr_keys_v1.txt True + ./ocr_db_crnn system ch_PP-OCRv3_det_slim_opt.nb ch_PP-OCRv3_rec_slim_opt.nb ch_ppocr_mobile_v2.0_cls_slim_opt.nb arm8 INT8 10 1 ./11.jpg config.txt ppocr_keys_v1.txt True # 仅使用文本检测模型,使用方式如下: -./ocr_db_crnn det ch_PP-OCRv2_det_slim_opt.nb arm8 INT8 10 1 ./11.jpg config.txt +./ocr_db_crnn det ch_PP-OCRv3_det_slim_opt.nb arm8 INT8 10 1 ./11.jpg config.txt # 仅使用文本识别模型,使用方式如下: -./ocr_db_crnn rec ch_PP-OCRv2_rec_slim_opt.nb arm8 INT8 10 1 word_1.jpg ppocr_keys_v1.txt config.txt +./ocr_db_crnn rec ch_PP-OCRv3_rec_slim_opt.nb arm8 INT8 10 1 word_1.jpg ppocr_keys_v1.txt config.txt ``` 如果对代码做了修改,则需要重新编译并push到手机上。 diff --git a/deploy/slim/quantization/README.md b/deploy/slim/quantization/README.md index 4c1d784b99aade614d78b4bd6fb20afef15f0f6f..7f1ff7ae22e78cded28f1689d66a5e41dd8950a2 100644 --- a/deploy/slim/quantization/README.md +++ b/deploy/slim/quantization/README.md @@ -22,7 +22,7 @@ ### 1. 安装PaddleSlim ```bash -pip3 install paddleslim==2.2.2 +pip3 install paddleslim==2.3.2 ``` ### 2. 准备训练好的模型 @@ -33,17 +33,7 @@ PaddleOCR提供了一系列训练好的[模型](../../../doc/doc_ch/models_list. 量化训练包括离线量化训练和在线量化训练,在线量化训练效果更好,需加载预训练模型,在定义好量化策略后即可对模型进行量化。 -量化训练的代码位于slim/quantization/quant.py 中,比如训练检测模型,训练指令如下: -```bash -python deploy/slim/quantization/quant.py -c configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml -o Global.pretrained_model='your trained model' Global.save_model_dir=./output/quant_model - -# 比如下载提供的训练模型 -wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_train.tar -tar -xf ch_ppocr_mobile_v2.0_det_train.tar -python deploy/slim/quantization/quant.py -c configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml -o Global.pretrained_model=./ch_ppocr_mobile_v2.0_det_train/best_accuracy Global.save_model_dir=./output/quant_model -``` - -模型蒸馏和模型量化可以同时使用,以PPOCRv3检测模型为例: +量化训练的代码位于slim/quantization/quant.py 中,比如训练检测模型,以PPOCRv3检测模型为例,训练指令如下: ``` # 下载检测预训练模型: wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_distill_train.tar @@ -58,7 +48,7 @@ python deploy/slim/quantization/quant.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_ 在得到量化训练保存的模型后,我们可以将其导出为inference_model,用于预测部署: ```bash -python deploy/slim/quantization/export_model.py -c configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml -o Global.checkpoints=output/quant_model/best_accuracy Global.save_inference_dir=./output/quant_inference_model +python deploy/slim/quantization/export_model.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml -o Global.checkpoints=output/quant_model/best_accuracy Global.save_inference_dir=./output/quant_inference_model ``` ### 5. 量化模型部署 diff --git a/deploy/slim/quantization/README_en.md b/deploy/slim/quantization/README_en.md index c6796ae9dc256496308e432023c45ef1026c3d92..f82c3d844e292ee76b95624f7632ed40301e5a4c 100644 --- a/deploy/slim/quantization/README_en.md +++ b/deploy/slim/quantization/README_en.md @@ -25,7 +25,7 @@ After training, if you want to further compress the model size and accelerate th ### 1. 
Install PaddleSlim ```bash -pip3 install paddleslim==2.2.2 +pip3 install paddleslim==2.3.2 ``` @@ -39,18 +39,7 @@ Quantization training includes offline quantization training and online quantiza Online quantization training is more effective. It is necessary to load the pre-trained model. After the quantization strategy is defined, the model can be quantified. -The code for quantization training is located in `slim/quantization/quant.py`. For example, to train a detection model, the training instructions are as follows: -```bash -python deploy/slim/quantization/quant.py -c configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml -o Global.pretrained_model='your trained model' Global.save_model_dir=./output/quant_model - -# download provided model -wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_train.tar -tar -xf ch_ppocr_mobile_v2.0_det_train.tar -python deploy/slim/quantization/quant.py -c configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml -o Global.pretrained_model=./ch_ppocr_mobile_v2.0_det_train/best_accuracy Global.save_model_dir=./output/quant_model -``` - - -Model distillation and model quantization can be used at the same time, taking the PPOCRv3 detection model as an example: +The code for quantization training is located in `slim/quantization/quant.py`. For example, the training instructions of slim PPOCRv3 detection model are as follows: ``` # download provided model wget https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_distill_train.tar @@ -66,7 +55,7 @@ If you want to quantify the text recognition model, you can modify the configura Once we got the model after pruning and fine-tuning, we can export it as an inference model for the deployment of predictive tasks: ```bash -python deploy/slim/quantization/export_model.py -c configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml -o Global.checkpoints=output/quant_model/best_accuracy Global.save_inference_dir=./output/quant_inference_model +python deploy/slim/quantization/export_model.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml -o Global.checkpoints=output/quant_model/best_accuracy Global.save_inference_dir=./output/quant_inference_model ``` ### 5. 
Deploy diff --git a/deploy/slim/quantization/export_model.py b/deploy/slim/quantization/export_model.py index fd1c3e5e109667fa74f5ade18b78f634e4d325db..bd132b625181cab853961efd2e2c38c411e9edf4 100755 --- a/deploy/slim/quantization/export_model.py +++ b/deploy/slim/quantization/export_model.py @@ -151,17 +151,24 @@ def main(): arch_config = config["Architecture"] - arch_config = config["Architecture"] + if arch_config["algorithm"] == "SVTR" and arch_config["Head"][ + "name"] != 'MultiHead': + input_shape = config["Eval"]["dataset"]["transforms"][-2][ + 'SVTRRecResizeImg']['image_shape'] + else: + input_shape = None if arch_config["algorithm"] in ["Distillation", ]: # distillation model archs = list(arch_config["Models"].values()) for idx, name in enumerate(model.model_name_list): sub_model_save_path = os.path.join(save_path, name, "inference") export_single_model(model.model_list[idx], archs[idx], - sub_model_save_path, logger, quanter) + sub_model_save_path, logger, input_shape, + quanter) else: save_path = os.path.join(save_path, "inference") - export_single_model(model, arch_config, save_path, logger, quanter) + export_single_model(model, arch_config, save_path, logger, input_shape, + quanter) if __name__ == "__main__": diff --git a/deploy/slim/quantization/quant.py b/deploy/slim/quantization/quant.py index 64521b5e06df61cf656da4087e6cd49f82adfadd..ef2c3e28f94e8b72d1aa7822fc88ecfd5c406b89 100755 --- a/deploy/slim/quantization/quant.py +++ b/deploy/slim/quantization/quant.py @@ -158,8 +158,7 @@ def main(config, device, logger, vdl_writer): pre_best_model_dict = dict() # load fp32 model to begin quantization - if config["Global"]["pretrained_model"] is not None: - pre_best_model_dict = load_model(config, model) + pre_best_model_dict = load_model(config, model, None, config['Architecture']["model_type"]) freeze_params = False if config['Architecture']["algorithm"] in ["Distillation"]: @@ -184,8 +183,7 @@ def main(config, device, logger, vdl_writer): model=model) # resume PACT training process - if config["Global"]["checkpoints"] is not None: - pre_best_model_dict = load_model(config, model, optimizer) + pre_best_model_dict = load_model(config, model, optimizer, config['Architecture']["model_type"]) # build metric eval_class = build_metric(config['Metric']) diff --git a/deploy/slim/quantization/quant_kl.py b/deploy/slim/quantization/quant_kl.py index cc3a455b971937fbb2e401b87112475341bd41f3..73e1a957e8606fd7cc8269e96eec1e274484db06 100755 --- a/deploy/slim/quantization/quant_kl.py +++ b/deploy/slim/quantization/quant_kl.py @@ -97,6 +97,17 @@ def sample_generator(loader): return __reader__ +def sample_generator_layoutxlm_ser(loader): + def __reader__(): + for indx, data in enumerate(loader): + input_ids = np.array(data[0]) + bbox = np.array(data[1]) + attention_mask = np.array(data[2]) + token_type_ids = np.array(data[3]) + images = np.array(data[4]) + yield [input_ids, bbox, attention_mask, token_type_ids, images] + + return __reader__ def main(config, device, logger, vdl_writer): # init dist environment @@ -107,16 +118,18 @@ def main(config, device, logger, vdl_writer): # build dataloader config['Train']['loader']['num_workers'] = 0 + is_layoutxlm_ser = config['Architecture']['model_type'] =='kie' and config['Architecture']['Backbone']['name'] == 'LayoutXLMForSer' train_dataloader = build_dataloader(config, 'Train', device, logger) if config['Eval']: config['Eval']['loader']['num_workers'] = 0 valid_dataloader = build_dataloader(config, 'Eval', device, logger) + if is_layoutxlm_ser: + 
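+            # For LayoutXLM SER models, calibrate post-training quantization
+            # on the eval loader, whose batches provide the (input_ids, bbox,
+            # attention_mask, token_type_ids, images) tensors consumed by
+            # sample_generator_layoutxlm_ser below.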
train_dataloader = valid_dataloader else: valid_dataloader = None paddle.enable_static() - place = paddle.CPUPlace() - exe = paddle.static.Executor(place) + exe = paddle.static.Executor(device) if 'inference_model' in global_config.keys(): # , 'inference_model'): inference_model_dir = global_config['inference_model'] @@ -127,6 +140,11 @@ def main(config, device, logger, vdl_writer): raise ValueError( "Please set inference model dir in Global.inference_model or Global.pretrained_model for post-quantazition" ) + + if is_layoutxlm_ser: + generator = sample_generator_layoutxlm_ser(train_dataloader) + else: + generator = sample_generator(train_dataloader) paddleslim.quant.quant_post_static( executor=exe, @@ -134,7 +152,7 @@ def main(config, device, logger, vdl_writer): model_filename='inference.pdmodel', params_filename='inference.pdiparams', quantize_model_path=global_config['save_inference_dir'], - sample_generator=sample_generator(train_dataloader), + sample_generator=generator, save_model_filename='inference.pdmodel', save_params_filename='inference.pdiparams', batch_size=1, diff --git a/doc/doc_ch/algorithm_det_ct.md b/doc/doc_ch/algorithm_det_ct.md new file mode 100644 index 0000000000000000000000000000000000000000..ea3522b7bf3c2dc17ef4f645bc47738477f07cf1 --- /dev/null +++ b/doc/doc_ch/algorithm_det_ct.md @@ -0,0 +1,95 @@ +# CT + +- [1. 算法简介](#1) +- [2. 环境配置](#2) +- [3. 模型训练、评估、预测](#3) + - [3.1 训练](#3-1) + - [3.2 评估](#3-2) + - [3.3 预测](#3-3) +- [4. 推理部署](#4) + - [4.1 Python推理](#4-1) + - [4.2 C++推理](#4-2) + - [4.3 Serving服务化部署](#4-3) + - [4.4 更多推理部署](#4-4) +- [5. FAQ](#5) + + +## 1. 算法简介 + +论文信息: +> [CentripetalText: An Efficient Text Instance Representation for Scene Text Detection](https://arxiv.org/abs/2107.05945) +> Tao Sheng, Jie Chen, Zhouhui Lian +> NeurIPS, 2021 + + +在Total-Text文本检测公开数据集上,算法复现效果如下: + +|模型|骨干网络|配置文件|precision|recall|Hmean|下载链接| +| --- | --- | --- | --- | --- | --- | --- | +|CT|ResNet18_vd|[configs/det/det_r18_vd_ct.yml](../../configs/det/det_r18_vd_ct.yml)|88.68%|81.70%|85.05%|[训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r18_ct_train.tar)| + + + +## 2. 环境配置 +请先参考[《运行环境准备》](./environment.md)配置PaddleOCR运行环境,参考[《项目克隆》](./clone.md)克隆项目代码。 + + + +## 3. 模型训练、评估、预测 + +CT模型使用Total-Text文本检测公开数据集训练得到,数据集下载可参考 [Total-Text-Dataset](https://github.com/cs-chan/Total-Text-Dataset/tree/master/Dataset), 我们将标签文件转成了paddleocr格式,转换好的标签文件下载参考[train.txt](https://paddleocr.bj.bcebos.com/dataset/ct_tipc/train.txt), [text.txt](https://paddleocr.bj.bcebos.com/dataset/ct_tipc/test.txt)。 + +请参考[文本检测训练教程](./detection.md)。PaddleOCR对代码进行了模块化,训练不同的检测模型只需要**更换配置文件**即可。 + + + +## 4. 推理部署 + + +### 4.1 Python推理 +首先将CT文本检测训练过程中保存的模型,转换成inference model。以基于Resnet18_vd骨干网络,在Total-Text英文数据集训练的模型为例( [模型下载地址](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r18_ct_train.tar) ),可以使用如下命令进行转换: + +```shell +python3 tools/export_model.py -c configs/det/det_r18_vd_ct.yml -o Global.pretrained_model=./det_r18_ct_train/best_accuracy Global.save_inference_dir=./inference/det_ct +``` + +CT文本检测模型推理,可以执行如下命令: + +```shell +python3 tools/infer/predict_det.py --image_dir="./doc/imgs_en/img623.jpg" --det_model_dir="./inference/det_ct/" --det_algorithm="CT" +``` + +可视化文本检测结果默认保存到`./inference_results`文件夹里面,结果文件的名称前缀为'det_res'。结果示例如下: + +![](../imgs_results/det_res_img623_ct.jpg) + + + +### 4.2 C++推理 + +暂不支持 + + +### 4.3 Serving服务化部署 + +暂不支持 + + +### 4.4 更多推理部署 + +暂不支持 + + +## 5. 
FAQ + + +## 引用 + +```bibtex +@inproceedings{sheng2021centripetaltext, + title={CentripetalText: An Efficient Text Instance Representation for Scene Text Detection}, + author={Tao Sheng and Jie Chen and Zhouhui Lian}, + booktitle={Thirty-Fifth Conference on Neural Information Processing Systems}, + year={2021} +} +``` diff --git a/doc/doc_ch/algorithm_kie_layoutxlm.md b/doc/doc_ch/algorithm_kie_layoutxlm.md index e693be49b7bc89e04b169fe74cf76525b2494948..0cbcad25016974207382a044e211c704082f6467 100644 --- a/doc/doc_ch/algorithm_kie_layoutxlm.md +++ b/doc/doc_ch/algorithm_kie_layoutxlm.md @@ -30,7 +30,7 @@ |模型|骨干网络|任务|配置文件|hmean|下载链接| | --- | --- |--|--- | --- | --- | |LayoutXLM|LayoutXLM-base|SER |[ser_layoutxlm_xfund_zh.yml](../../configs/kie/layoutlm_series/ser_layoutxlm_xfund_zh.yml)|90.38%|[训练模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar)/[推理模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh_infer.tar)| -|LayoutXLM|LayoutXLM-base|RE | [re_layoutxlm_xfund_zh.yml](../../configs/kie/layoutlm_series/re_layoutxlm_xfund_zh.yml)|74.83%|[训练模型](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar)/[推理模型(coming soon)]()| +|LayoutXLM|LayoutXLM-base|RE | [re_layoutxlm_xfund_zh.yml](../../configs/kie/layoutlm_series/re_layoutxlm_xfund_zh.yml)|74.83%|[训练模型](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar)/[推理模型](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh_infer.tar)| @@ -52,14 +52,14 @@ ### 4.1 Python推理 -**注:** 目前RE任务推理过程仍在适配中,下面以SER任务为例,介绍基于LayoutXLM模型的关键信息抽取过程。 +- SER 首先将训练得到的模型转换成inference model。LayoutXLM模型在XFUND_zh数据集上训练的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar)),可以使用下面的命令进行转换。 ``` bash wget https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar tar -xf ser_LayoutXLM_xfun_zh.tar -python3 tools/export_model.py -c configs/kie/layoutlm_series/ser_layoutxlm_xfund_zh.yml -o Architecture.Backbone.checkpoints=./ser_LayoutXLM_xfun_zh/best_accuracy Global.save_inference_dir=./inference/ser_layoutxlm +python3 tools/export_model.py -c configs/kie/layoutlm_series/ser_layoutxlm_xfund_zh.yml -o Architecture.Backbone.checkpoints=./ser_LayoutXLM_xfun_zh Global.save_inference_dir=./inference/ser_layoutxlm_infer ``` LayoutXLM模型基于SER任务进行推理,可以执行如下命令: @@ -80,6 +80,34 @@ SER可视化结果默认保存到`./output`文件夹里面,结果示例如下 +- RE + +首先将训练得到的模型转换成inference model。LayoutXLM模型在XFUND_zh数据集上训练的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar)),可以使用下面的命令进行转换。 + +``` bash +wget https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar +tar -xf re_LayoutXLM_xfun_zh.tar +python3 tools/export_model.py -c configs/kie/layoutlm_series/re_layoutxlm_xfund_zh.yml -o Architecture.Backbone.checkpoints=./re_LayoutXLM_xfun_zh Global.save_inference_dir=./inference/ser_layoutxlm_infer +``` + +LayoutXLM模型基于RE任务进行推理,可以执行如下命令: + +```bash +cd ppstructure +python3 kie/predict_kie_token_ser_re.py \ + --kie_algorithm=LayoutXLM \ + --re_model_dir=../inference/re_layoutxlm_infer \ + --ser_model_dir=../inference/ser_layoutxlm_infer \ + --image_dir=./docs/kie/input/zh_val_42.jpg \ + --ser_dict_path=../train_data/XFUND/class_list_xfun.txt \ + --vis_font_path=../doc/fonts/simfang.ttf +``` + +RE可视化结果默认保存到`./output`文件夹里面,结果示例如下: + +
### 4.2 C++推理部署 diff --git a/doc/doc_ch/algorithm_kie_vi_layoutxlm.md b/doc/doc_ch/algorithm_kie_vi_layoutxlm.md index f1bb4b1e62736e88594196819dcc41980f1716bf..1ec778a899d8e3e66164f6d1deb902bca36ad65a 100644 --- a/doc/doc_ch/algorithm_kie_vi_layoutxlm.md +++ b/doc/doc_ch/algorithm_kie_vi_layoutxlm.md @@ -23,7 +23,7 @@ VI-LayoutXLM基于LayoutXLM进行改进,在下游任务训练过程中,去 |模型|骨干网络|任务|配置文件|hmean|下载链接| | --- | --- |---| --- | --- | --- | |VI-LayoutXLM |VI-LayoutXLM-base | SER |[ser_vi_layoutxlm_xfund_zh_udml.yml](../../configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh_udml.yml)|93.19%|[训练模型](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_pretrained.tar)/[推理模型](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_infer.tar)| -|VI-LayoutXLM |VI-LayoutXLM-base |RE | [re_vi_layoutxlm_xfund_zh_udml.yml](../../configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh_udml.yml)|83.92%|[训练模型](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_pretrained.tar)/[推理模型(coming soon)]()| +|VI-LayoutXLM |VI-LayoutXLM-base |RE | [re_vi_layoutxlm_xfund_zh_udml.yml](../../configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh_udml.yml)|83.92%|[训练模型](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_pretrained.tar)/[推理模型](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_infer.tar)| @@ -45,7 +45,7 @@ VI-LayoutXLM基于LayoutXLM进行改进,在下游任务训练过程中,去 ### 4.1 Python推理 -**注:** 目前RE任务推理过程仍在适配中,下面以SER任务为例,介绍基于VI-LayoutXLM模型的关键信息抽取过程。 +- SER 首先将训练得到的模型转换成inference model。以VI-LayoutXLM模型在XFUND_zh数据集上训练的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_pretrained.tar)),可以使用下面的命令进行转换。 @@ -74,6 +74,36 @@ SER可视化结果默认保存到`./output`文件夹里面,结果示例如下 +- RE + +首先将训练得到的模型转换成inference model。以VI-LayoutXLM模型在XFUND_zh数据集上训练的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_pretrained.tar)),可以使用下面的命令进行转换。 + +``` bash +wget https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_pretrained.tar +tar -xf re_vi_layoutxlm_xfund_pretrained.tar +python3 tools/export_model.py -c configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh.yml -o Architecture.Backbone.checkpoints=./re_vi_layoutxlm_xfund_pretrained/best_accuracy Global.save_inference_dir=./inference/re_vi_layoutxlm_infer +``` + +VI-LayoutXLM模型基于RE任务进行推理,可以执行如下命令: + +```bash +cd ppstructure +python3 kie/predict_kie_token_ser_re.py \ + --kie_algorithm=LayoutXLM \ + --re_model_dir=../inference/re_vi_layoutxlm_infer \ + --ser_model_dir=../inference/ser_vi_layoutxlm_infer \ + --use_visual_backbone=False \ + --image_dir=./docs/kie/input/zh_val_42.jpg \ + --ser_dict_path=../train_data/XFUND/class_list_xfun.txt \ + --vis_font_path=../doc/fonts/simfang.ttf \ + --ocr_order_method="tb-yx" +``` + +RE可视化结果默认保存到`./output`文件夹里面,结果示例如下: + +
### 4.2 C++推理部署 diff --git a/doc/doc_ch/algorithm_overview.md b/doc/doc_ch/algorithm_overview.md index ecb0e9dfefbfdef2f8cea273c4e3de468aa29415..b9b8cfa6798b5d0bf5b33e562c65996bf54c8c7c 100755 --- a/doc/doc_ch/algorithm_overview.md +++ b/doc/doc_ch/algorithm_overview.md @@ -100,8 +100,8 @@ PaddleOCR将**持续新增**支持OCR领域前沿算法与模型,**欢迎广 |ViTSTR|ViTSTR| 79.82% | rec_vitstr_none_ce | [训练模型](https://paddleocr.bj.bcebos.com/rec_vitstr_none_ce_train.tar) | |ABINet|Resnet45| 90.75% | rec_r45_abinet | [训练模型](https://paddleocr.bj.bcebos.com/rec_r45_abinet_train.tar) | |VisionLAN|Resnet45| 90.30% | rec_r45_visionlan | [训练模型](https://paddleocr.bj.bcebos.com/rec_r45_visionlan_train.tar) | -|SPIN|ResNet32| 90.00% | rec_r32_gaspin_bilstm_att | coming soon | -|RobustScanner|ResNet31| 87.77% | rec_r31_robustscanner | coming soon | +|SPIN|ResNet32| 90.00% | rec_r32_gaspin_bilstm_att | [训练模型](https://paddleocr.bj.bcebos.com/contribution/rec_r32_gaspin_bilstm_att.tar) | +|RobustScanner|ResNet31| 87.77% | rec_r31_robustscanner | [训练模型](https://paddleocr.bj.bcebos.com/contribution/rec_r31_robustscanner.tar)| diff --git a/doc/doc_ch/algorithm_rec_robustscanner.md b/doc/doc_ch/algorithm_rec_robustscanner.md index 869f9a7c00b617de87ab3c96326e18e536bc18a8..a1ab3baf038f573c50fcc0b67fd8297459777cf8 100644 --- a/doc/doc_ch/algorithm_rec_robustscanner.md +++ b/doc/doc_ch/algorithm_rec_robustscanner.md @@ -26,7 +26,7 @@ Zhang |模型|骨干网络|配置文件|Acc|下载链接| | --- | --- | --- | --- | --- | -|RobustScanner|ResNet31|[rec_r31_robustscanner.yml](../../configs/rec/rec_r31_robustscanner.yml)|87.77%|coming soon| +|RobustScanner|ResNet31|[rec_r31_robustscanner.yml](../../configs/rec/rec_r31_robustscanner.yml)|87.77%|[训练模型](https://paddleocr.bj.bcebos.com/contribution/rec_r31_robustscanner.tar)| 注:除了使用MJSynth和SynthText两个文字识别数据集外,还加入了[SynthAdd](https://pan.baidu.com/share/init?surl=uV0LtoNmcxbO-0YA7Ch4dg)数据(提取码:627x),和部分真实数据,具体数据细节可以参考论文。 diff --git a/doc/doc_ch/algorithm_rec_spin.md b/doc/doc_ch/algorithm_rec_spin.md index c996992d2fa6297e6086ffae4bc36ad3e880873d..908a85a417c4070b95630b37b0830e08aae3ff4f 100644 --- a/doc/doc_ch/algorithm_rec_spin.md +++ b/doc/doc_ch/algorithm_rec_spin.md @@ -26,7 +26,7 @@ SPIN收录于AAAI2020。主要用于OCR识别任务。在任意形状文本识 |模型|骨干网络|配置文件|Acc|下载链接| | --- | --- | --- | --- | --- | -|SPIN|ResNet32|[rec_r32_gaspin_bilstm_att.yml](../../configs/rec/rec_r32_gaspin_bilstm_att.yml)|90.0%|coming soon| +|SPIN|ResNet32|[rec_r32_gaspin_bilstm_att.yml](../../configs/rec/rec_r32_gaspin_bilstm_att.yml)|90.0%|[训练模型](https://paddleocr.bj.bcebos.com/contribution/rec_r32_gaspin_bilstm_att.tar)| diff --git a/doc/doc_ch/inference_args.md b/doc/doc_ch/inference_args.md index 36efc6fbf7a6ec62bc700964dc13261fecdb9bd5..24e7223e397c94fe65b0f26d993fc507b323ed16 100644 --- a/doc/doc_ch/inference_args.md +++ b/doc/doc_ch/inference_args.md @@ -7,6 +7,7 @@ | 参数名称 | 类型 | 默认值 | 含义 | | :--: | :--: | :--: | :--: | | image_dir | str | 无,必须显式指定 | 图像或者文件夹路径 | +| page_num | int | 0 | 当输入类型为pdf文件时有效,指定预测前面page_num页,默认预测所有页 | | vis_font_path | str | "./doc/fonts/simfang.ttf" | 用于可视化的字体路径 | | drop_score | float | 0.5 | 识别得分小于该值的结果会被丢弃,不会作为返回结果 | | use_pdserving | bool | False | 是否使用Paddle Serving进行预测 | diff --git a/doc/doc_en/algorithm_det_ct_en.md b/doc/doc_en/algorithm_det_ct_en.md new file mode 100644 index 0000000000000000000000000000000000000000..d56b3fc6b3353bacb1f26fba3873ba5276b10c8b --- /dev/null +++ b/doc/doc_en/algorithm_det_ct_en.md @@ -0,0 +1,96 @@ +# CT + +- [1. Introduction](#1) +- [2. Environment](#2) +- [3. 
Model Training / Evaluation / Prediction](#3) + - [3.1 Training](#3-1) + - [3.2 Evaluation](#3-2) + - [3.3 Prediction](#3-3) +- [4. Inference and Deployment](#4) + - [4.1 Python Inference](#4-1) + - [4.2 C++ Inference](#4-2) + - [4.3 Serving](#4-3) + - [4.4 More](#4-4) +- [5. FAQ](#5) + + +## 1. Introduction + +Paper: +> [CentripetalText: An Efficient Text Instance Representation for Scene Text Detection](https://arxiv.org/abs/2107.05945) +> Tao Sheng, Jie Chen, Zhouhui Lian +> NeurIPS, 2021 + + +On the Total-Text dataset, the text detection result is as follows: + +|Model|Backbone|Configuration|Precision|Recall|Hmean|Download| +| --- | --- | --- | --- | --- | --- | --- | +|CT|ResNet18_vd|[configs/det/det_r18_vd_ct.yml](../../configs/det/det_r18_vd_ct.yml)|88.68%|81.70%|85.05%|[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r18_ct_train.tar)| + + + +## 2. Environment +Please prepare your environment referring to [prepare the environment](./environment_en.md) and [clone the repo](./clone_en.md). + + + +## 3. Model Training / Evaluation / Prediction + + +The above CT model is trained using the Total-Text text detection public dataset. For the download of the dataset, please refer to [Total-Text-Dataset](https://github.com/cs-chan/Total-Text-Dataset/tree/master/Dataset). PaddleOCR format annotation download link [train.txt](https://paddleocr.bj.bcebos.com/dataset/ct_tipc/train.txt), [test.txt](https://paddleocr.bj.bcebos.com/dataset/ct_tipc/test.txt). + + +Please refer to [text detection training tutorial](./detection_en.md). PaddleOCR has modularized the code structure, so that you only need to **replace the configuration file** to train different detection models. + + +## 4. Inference and Deployment + + +### 4.1 Python Inference +First, convert the model saved in the CT text detection training process into an inference model. Taking the model based on the Resnet18_vd backbone network and trained on the Total Text English dataset as example ([model download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r18_ct_train.tar)), you can use the following command to convert: + +```shell +python3 tools/export_model.py -c configs/det/det_r18_vd_ct.yml -o Global.pretrained_model=./det_r18_ct_train/best_accuracy Global.save_inference_dir=./inference/det_ct +``` + +CT text detection model inference, you can execute the following command: + +```shell +python3 tools/infer/predict_det.py --image_dir="./doc/imgs_en/img623.jpg" --det_model_dir="./inference/det_ct/" --det_algorithm="CT" +``` + +The visualized text detection results are saved to the `./inference_results` folder by default, and the name of the result file is prefixed with 'det_res'. Examples of results are as follows: + +![](../imgs_results/det_res_img623_ct.jpg) + + + +### 4.2 C++ Inference + +Not supported + + +### 4.3 Serving + +Not supported + + +### 4.4 More + +Not supported + + +## 5. 
FAQ + + +## Citation + +```bibtex +@inproceedings{sheng2021centripetaltext, + title={CentripetalText: An Efficient Text Instance Representation for Scene Text Detection}, + author={Tao Sheng and Jie Chen and Zhouhui Lian}, + booktitle={Thirty-Fifth Conference on Neural Information Processing Systems}, + year={2021} +} +``` diff --git a/doc/doc_en/algorithm_kie_layoutxlm_en.md b/doc/doc_en/algorithm_kie_layoutxlm_en.md index 910c1f4d497a6e503f0a7a5ec26dbeceb2d321a1..0c82b0423b2cdc11b1817e7575629a659e599374 100644 --- a/doc/doc_en/algorithm_kie_layoutxlm_en.md +++ b/doc/doc_en/algorithm_kie_layoutxlm_en.md @@ -28,7 +28,7 @@ On XFUND_zh dataset, the algorithm reproduction Hmean is as follows. |Model|Backbone|Task |Cnnfig|Hmean|Download link| | --- | --- |--|--- | --- | --- | |LayoutXLM|LayoutXLM-base|SER |[ser_layoutxlm_xfund_zh.yml](../../configs/kie/layoutlm_series/ser_layoutxlm_xfund_zh.yml)|90.38%|[trained model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar)/[inference model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh_infer.tar)| -|LayoutXLM|LayoutXLM-base|RE | [re_layoutxlm_xfund_zh.yml](../../configs/kie/layoutlm_series/re_layoutxlm_xfund_zh.yml)|74.83%|[trained model](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar)/[inference model(coming soon)]()| +|LayoutXLM|LayoutXLM-base|RE | [re_layoutxlm_xfund_zh.yml](../../configs/kie/layoutlm_series/re_layoutxlm_xfund_zh.yml)|74.83%|[trained model](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar)/[inference model](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh_infer.tar)| ## 2. Environment @@ -46,7 +46,7 @@ Please refer to [KIE tutorial](./kie_en.md)。PaddleOCR has modularized the code ### 4.1 Python Inference -**Note:** Currently, the RE model inference process is still in the process of adaptation. We take SER model as an example to introduce the KIE process based on LayoutXLM model. +- SER First, we need to export the trained model into inference model. Take LayoutXLM model trained on XFUND_zh as an example ([trained model download link](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar)). Use the following command to export. @@ -54,7 +54,7 @@ First, we need to export the trained model into inference model. Take LayoutXLM ``` bash wget https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar tar -xf ser_LayoutXLM_xfun_zh.tar -python3 tools/export_model.py -c configs/kie/layoutlm_series/ser_layoutxlm_xfund_zh.yml -o Architecture.Backbone.checkpoints=./ser_LayoutXLM_xfun_zh/best_accuracy Global.save_inference_dir=./inference/ser_layoutxlm +python3 tools/export_model.py -c configs/kie/layoutlm_series/ser_layoutxlm_xfund_zh.yml -o Architecture.Backbone.checkpoints=./ser_LayoutXLM_xfun_zh Global.save_inference_dir=./inference/ser_layoutxlm_infer ``` Use the following command to infer using LayoutXLM SER model. @@ -77,6 +77,38 @@ The SER visualization results are saved in the `./output` directory by default. +- RE + +First, we need to export the trained model into inference model. Take LayoutXLM model trained on XFUND_zh as an example ([trained model download link](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar)). Use the following command to export. 
+ + +``` bash +wget https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar +tar -xf re_LayoutXLM_xfun_zh.tar +python3 tools/export_model.py -c configs/kie/layoutlm_series/re_layoutxlm_xfund_zh.yml -o Architecture.Backbone.checkpoints=./re_LayoutXLM_xfun_zh Global.save_inference_dir=./inference/re_layoutxlm_infer +``` + +Use the following command to infer using LayoutXLM RE model. + + +```bash +cd ppstructure +python3 kie/predict_kie_token_ser_re.py \ + --kie_algorithm=LayoutXLM \ + --re_model_dir=../inference/re_layoutxlm_infer \ + --ser_model_dir=../inference/ser_layoutxlm_infer \ + --image_dir=./docs/kie/input/zh_val_42.jpg \ + --ser_dict_path=../train_data/XFUND/class_list_xfun.txt \ + --vis_font_path=../doc/fonts/simfang.ttf +``` +The RE visualization results are saved in the `./output` directory by default. The results are as follows. + + +
+<!-- RE visualization result image -->
+ + ### 4.2 C++ Inference Not supported diff --git a/doc/doc_en/algorithm_kie_vi_layoutxlm_en.md b/doc/doc_en/algorithm_kie_vi_layoutxlm_en.md index 12b6e1bddbd03b820ce33ba86de3d430a44f8987..798a52ca8cd9ce0de6b942294a22c11f64c9ac2f 100644 --- a/doc/doc_en/algorithm_kie_vi_layoutxlm_en.md +++ b/doc/doc_en/algorithm_kie_vi_layoutxlm_en.md @@ -22,7 +22,7 @@ On XFUND_zh dataset, the algorithm reproduction Hmean is as follows. |Model|Backbone|Task |Cnnfig|Hmean|Download link| | --- | --- |---| --- | --- | --- | |VI-LayoutXLM |VI-LayoutXLM-base | SER |[ser_vi_layoutxlm_xfund_zh_udml.yml](../../configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh_udml.yml)|93.19%|[trained model](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_pretrained.tar)/[inference model](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_infer.tar)| -|VI-LayoutXLM |VI-LayoutXLM-base |RE | [re_vi_layoutxlm_xfund_zh_udml.yml](../../configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh_udml.yml)|83.92%|[trained model](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_pretrained.tar)/[inference model(coming soon)]()| +|VI-LayoutXLM |VI-LayoutXLM-base |RE | [re_vi_layoutxlm_xfund_zh_udml.yml](../../configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh_udml.yml)|83.92%|[trained model](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_pretrained.tar)/[inference model](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_infer.tar)| Please refer to ["Environment Preparation"](./environment_en.md) to configure the PaddleOCR environment, and refer to ["Project Clone"](./clone_en.md) to clone the project code. @@ -37,7 +37,7 @@ Please refer to [KIE tutorial](./kie_en.md)。PaddleOCR has modularized the code ### 4.1 Python Inference -**Note:** Currently, the RE model inference process is still in the process of adaptation. We take SER model as an example to introduce the KIE process based on VI-LayoutXLM model. +- SER First, we need to export the trained model into inference model. Take VI-LayoutXLM model trained on XFUND_zh as an example ([trained model download link](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_pretrained.tar)). Use the following command to export. @@ -70,6 +70,41 @@ The SER visualization results are saved in the `./output` folder by default. The +- RE + +First, we need to export the trained model into inference model. Take VI-LayoutXLM model trained on XFUND_zh as an example ([trained model download link](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_pretrained.tar)). Use the following command to export. + + +``` bash +wget https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_pretrained.tar +tar -xf re_vi_layoutxlm_xfund_pretrained.tar +python3 tools/export_model.py -c configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh.yml -o Architecture.Backbone.checkpoints=./re_vi_layoutxlm_xfund_pretrained/best_accuracy Global.save_inference_dir=./inference/re_vi_layoutxlm_infer +``` + +Use the following command to infer using VI-LayoutXLM RE model. 
+ + +```bash +cd ppstructure +python3 kie/predict_kie_token_ser_re.py \ + --kie_algorithm=LayoutXLM \ + --re_model_dir=../inference/re_vi_layoutxlm_infer \ + --ser_model_dir=../inference/ser_vi_layoutxlm_infer \ + --use_visual_backbone=False \ + --image_dir=./docs/kie/input/zh_val_42.jpg \ + --ser_dict_path=../train_data/XFUND/class_list_xfun.txt \ + --vis_font_path=../doc/fonts/simfang.ttf \ + --ocr_order_method="tb-yx" +``` + +The RE visualization results are saved in the `./output` folder by default. The results are as follows. + + +
+<!-- RE visualization result image -->
+ + ### 4.2 C++ Inference Not supported diff --git a/doc/doc_en/algorithm_overview_en.md b/doc/doc_en/algorithm_overview_en.md index bca22f78482980bed18d6447d0cf07b27c26720d..073bca1031beb9e96f73db6387386a93be419b3d 100755 --- a/doc/doc_en/algorithm_overview_en.md +++ b/doc/doc_en/algorithm_overview_en.md @@ -97,8 +97,8 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r |ViTSTR|ViTSTR| 79.82% | rec_vitstr_none_ce | [trained model](https://paddleocr.bj.bcebos.com/rec_vitstr_none_none_train.tar) | |ABINet|Resnet45| 90.75% | rec_r45_abinet | [trained model](https://paddleocr.bj.bcebos.com/rec_r45_abinet_train.tar) | |VisionLAN|Resnet45| 90.30% | rec_r45_visionlan | [trained model](https://paddleocr.bj.bcebos.com/rec_r45_visionlan_train.tar) | -|SPIN|ResNet32| 90.00% | rec_r32_gaspin_bilstm_att | coming soon | -|RobustScanner|ResNet31| 87.77% | rec_r31_robustscanner | coming soon | +|SPIN|ResNet32| 90.00% | rec_r32_gaspin_bilstm_att | [trained model](https://paddleocr.bj.bcebos.com/contribution/rec_r32_gaspin_bilstm_att.tar) | +|RobustScanner|ResNet31| 87.77% | rec_r31_robustscanner | [trained model](https://paddleocr.bj.bcebos.com/contribution/rec_r31_robustscanner.tar)| diff --git a/doc/doc_en/algorithm_rec_robustscanner_en.md b/doc/doc_en/algorithm_rec_robustscanner_en.md index a324a6d547a9e448566276234c750ad4497abf9c..99372d51300e6571827db7181a4f8537db4f1493 100644 --- a/doc/doc_en/algorithm_rec_robustscanner_en.md +++ b/doc/doc_en/algorithm_rec_robustscanner_en.md @@ -26,7 +26,7 @@ Using MJSynth and SynthText two text recognition datasets for training, and eval |Model|Backbone|config|Acc|Download link| | --- | --- | --- | --- | --- | -|RobustScanner|ResNet31|[rec_r31_robustscanner.yml](../../configs/rec/rec_r31_robustscanner.yml)|87.77%|coming soon| +|RobustScanner|ResNet31|[rec_r31_robustscanner.yml](../../configs/rec/rec_r31_robustscanner.yml)|87.77%|[trained model](https://paddleocr.bj.bcebos.com/contribution/rec_r31_robustscanner.tar)| Note:In addition to using the two text recognition datasets MJSynth and SynthText, [SynthAdd](https://pan.baidu.com/share/init?surl=uV0LtoNmcxbO-0YA7Ch4dg) data (extraction code: 627x), and some real data are used in training, the specific data details can refer to the paper. 
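The overview table and the RobustScanner doc above now link released trained models instead of "coming soon". As a rough, illustrative sketch of how such an archive is typically turned into an inference model (it follows the export pattern used by the other recognition docs in this repo; the `best_accuracy` checkpoint prefix inside the tar and the output directory name are assumptions, not taken from this diff):

```shell
# Hypothetical example: download the newly linked RobustScanner weights and export them.
wget https://paddleocr.bj.bcebos.com/contribution/rec_r31_robustscanner.tar
tar -xf rec_r31_robustscanner.tar
# The checkpoint prefix is assumed to follow the usual <archive>/best_accuracy layout.
python3 tools/export_model.py -c configs/rec/rec_r31_robustscanner.yml \
    -o Global.pretrained_model=./rec_r31_robustscanner/best_accuracy \
       Global.save_inference_dir=./inference/rec_r31_robustscanner
```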
diff --git a/doc/doc_en/algorithm_rec_spin_en.md b/doc/doc_en/algorithm_rec_spin_en.md index 43ab30ce7d96cbb64ddf87156fee3012d666b2bf..03f8d8f69986fc5eb14cfdf294fc25fafb06e269 100644 --- a/doc/doc_en/algorithm_rec_spin_en.md +++ b/doc/doc_en/algorithm_rec_spin_en.md @@ -25,7 +25,7 @@ Using MJSynth and SynthText two text recognition datasets for training, and eval |Model|Backbone|config|Acc|Download link| | --- | --- | --- | --- | --- | -|SPIN|ResNet32|[rec_r32_gaspin_bilstm_att.yml](../../configs/rec/rec_r32_gaspin_bilstm_att.yml)|90.0%|coming soon| +|SPIN|ResNet32|[rec_r32_gaspin_bilstm_att.yml](../../configs/rec/rec_r32_gaspin_bilstm_att.yml)|90.0%|[trained model](https://paddleocr.bj.bcebos.com/contribution/rec_r32_gaspin_bilstm_att.tar) | diff --git a/doc/doc_en/inference_args_en.md b/doc/doc_en/inference_args_en.md index f2c99fc8297d47f27a219bf7d8e7f2ea518257f0..b28cd8436da62dcd10f96f17751db9384ebcaa8d 100644 --- a/doc/doc_en/inference_args_en.md +++ b/doc/doc_en/inference_args_en.md @@ -7,6 +7,7 @@ When using PaddleOCR for model inference, you can customize the modification par | parameters | type | default | implication | | :--: | :--: | :--: | :--: | | image_dir | str | None, must be specified explicitly | Image or folder path | +| page_num | int | 0 | Valid when the input type is pdf file, specify to predict the previous page_num pages, all pages are predicted by default | | vis_font_path | str | "./doc/fonts/simfang.ttf" | font path for visualization | | drop_score | float | 0.5 | Results with a recognition score less than this value will be discarded and will not be returned as results | | use_pdserving | bool | False | Whether to use Paddle Serving for prediction | diff --git a/doc/imgs_results/det_res_img623_ct.jpg b/doc/imgs_results/det_res_img623_ct.jpg new file mode 100644 index 0000000000000000000000000000000000000000..2c5f57d96cca896c70d9e0d33ba80a0177a8aeb9 Binary files /dev/null and b/doc/imgs_results/det_res_img623_ct.jpg differ diff --git a/paddleocr.py b/paddleocr.py index 0b7aed36279081f50208f75272fc54c5081929a7..d34b8f78a56a8d8d5455c18e7e1cf1e75df8f3f9 100644 --- a/paddleocr.py +++ b/paddleocr.py @@ -414,6 +414,33 @@ def get_model_config(type, version, model_type, lang): return model_urls[version][model_type][lang] +def img_decode(content: bytes): + np_arr = np.frombuffer(content, dtype=np.uint8) + return cv2.imdecode(np_arr, cv2.IMREAD_COLOR) + + +def check_img(img): + if isinstance(img, bytes): + img = img_decode(img) + if isinstance(img, str): + # download net image + if is_link(img): + download_with_progressbar(img, 'tmp.jpg') + img = 'tmp.jpg' + image_file = img + img, flag, _ = check_and_read(image_file) + if not flag: + with open(image_file, 'rb') as f: + img = img_decode(f.read()) + if img is None: + logger.error("error in loading image:{}".format(image_file)) + return None + if isinstance(img, np.ndarray) and len(img.shape) == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + + return img + + class PaddleOCR(predict_system.TextSystem): def __init__(self, **kwargs): """ @@ -453,10 +480,11 @@ class PaddleOCR(predict_system.TextSystem): params.rec_image_shape = "3, 48, 320" else: params.rec_image_shape = "3, 32, 320" - # download model - maybe_download(params.det_model_dir, det_url) - maybe_download(params.rec_model_dir, rec_url) - maybe_download(params.cls_model_dir, cls_url) + # download model if using paddle infer + if not params.use_onnx: + maybe_download(params.det_model_dir, det_url) + maybe_download(params.rec_model_dir, rec_url) + 
maybe_download(params.cls_model_dir, cls_url) if params.det_algorithm not in SUPPORT_DET_MODEL: logger.error('det_algorithm must in {}'.format(SUPPORT_DET_MODEL)) @@ -482,7 +510,7 @@ class PaddleOCR(predict_system.TextSystem): rec: use text recognition or not. If false, only det will be exec. Default is True cls: use angle classifier or not. Default is True. If true, the text with rotation of 180 degrees can be recognized. If no text is rotated by 180 degrees, use cls=False to get better performance. Text with rotation of 90 or 270 degrees can be recognized even if cls=False. """ - assert isinstance(img, (np.ndarray, list, str)) + assert isinstance(img, (np.ndarray, list, str, bytes)) if isinstance(img, list) and det == True: logger.error('When input a list of images, det must be false') exit(0) @@ -491,22 +519,8 @@ class PaddleOCR(predict_system.TextSystem): 'Since the angle classifier is not initialized, the angle classifier will not be uesd during the forward process' ) - if isinstance(img, str): - # download net image - if img.startswith('http'): - download_with_progressbar(img, 'tmp.jpg') - img = 'tmp.jpg' - image_file = img - img, flag, _ = check_and_read(image_file) - if not flag: - with open(image_file, 'rb') as f: - np_arr = np.frombuffer(f.read(), dtype=np.uint8) - img = cv2.imdecode(np_arr, cv2.IMREAD_COLOR) - if img is None: - logger.error("error in loading image:{}".format(image_file)) - return None - if isinstance(img, np.ndarray) and len(img.shape) == 2: - img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + img = check_img(img) + if det and rec: dt_boxes, rec_res, _ = self.__call__(img, cls) return [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)] @@ -585,23 +599,7 @@ class PPStructure(StructureSystem): super().__init__(params) def __call__(self, img, return_ocr_result_in_table=False, img_idx=0): - if isinstance(img, str): - # download net image - if img.startswith('http'): - download_with_progressbar(img, 'tmp.jpg') - img = 'tmp.jpg' - image_file = img - img, flag, _ = check_and_read(image_file) - if not flag: - with open(image_file, 'rb') as f: - np_arr = np.frombuffer(f.read(), dtype=np.uint8) - img = cv2.imdecode(np_arr, cv2.IMREAD_COLOR) - if img is None: - logger.error("error in loading image:{}".format(image_file)) - return None - if isinstance(img, np.ndarray) and len(img.shape) == 2: - img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) - + img = check_img(img) res, _ = super().__call__( img, return_ocr_result_in_table, img_idx=img_idx) return res @@ -644,7 +642,7 @@ def main(): if not flag_pdf: if img is None: - logger.error("error in loading image:{}".format(image_file)) + logger.error("error in loading image:{}".format(img_path)) continue img_paths = [[img_path, img]] else: diff --git a/ppocr/data/imaug/__init__.py b/ppocr/data/imaug/__init__.py index 102f48fcc19e59d9f8ffb0ad496f54cc64864f7d..863988cccfa9d9f2c865a444410d4245687f49ee 100644 --- a/ppocr/data/imaug/__init__.py +++ b/ppocr/data/imaug/__init__.py @@ -43,6 +43,7 @@ from .vqa import * from .fce_aug import * from .fce_targets import FCENetTargets +from .ct_process import * def transform(data, ops=None): diff --git a/ppocr/data/imaug/ct_process.py b/ppocr/data/imaug/ct_process.py new file mode 100644 index 0000000000000000000000000000000000000000..59715090036e1020800950b02b9ea06ab5c8d4c2 --- /dev/null +++ b/ppocr/data/imaug/ct_process.py @@ -0,0 +1,355 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import cv2 +import random +import pyclipper +import paddle + +import numpy as np +import Polygon as plg +import scipy.io as scio + +from PIL import Image +import paddle.vision.transforms as transforms + + +class RandomScale(): + def __init__(self, short_size=640, **kwargs): + self.short_size = short_size + + def scale_aligned(self, img, scale): + oh, ow = img.shape[0:2] + h = int(oh * scale + 0.5) + w = int(ow * scale + 0.5) + if h % 32 != 0: + h = h + (32 - h % 32) + if w % 32 != 0: + w = w + (32 - w % 32) + img = cv2.resize(img, dsize=(w, h)) + factor_h = h / oh + factor_w = w / ow + return img, factor_h, factor_w + + def __call__(self, data): + img = data['image'] + + h, w = img.shape[0:2] + random_scale = np.array([0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3]) + scale = (np.random.choice(random_scale) * self.short_size) / min(h, w) + img, factor_h, factor_w = self.scale_aligned(img, scale) + + data['scale_factor'] = (factor_w, factor_h) + data['image'] = img + return data + + +class MakeShrink(): + def __init__(self, kernel_scale=0.7, **kwargs): + self.kernel_scale = kernel_scale + + def dist(self, a, b): + return np.linalg.norm((a - b), ord=2, axis=0) + + def perimeter(self, bbox): + peri = 0.0 + for i in range(bbox.shape[0]): + peri += self.dist(bbox[i], bbox[(i + 1) % bbox.shape[0]]) + return peri + + def shrink(self, bboxes, rate, max_shr=20): + rate = rate * rate + shrinked_bboxes = [] + for bbox in bboxes: + area = plg.Polygon(bbox).area() + peri = self.perimeter(bbox) + + try: + pco = pyclipper.PyclipperOffset() + pco.AddPath(bbox, pyclipper.JT_ROUND, + pyclipper.ET_CLOSEDPOLYGON) + offset = min( + int(area * (1 - rate) / (peri + 0.001) + 0.5), max_shr) + + shrinked_bbox = pco.Execute(-offset) + if len(shrinked_bbox) == 0: + shrinked_bboxes.append(bbox) + continue + + shrinked_bbox = np.array(shrinked_bbox[0]) + if shrinked_bbox.shape[0] <= 2: + shrinked_bboxes.append(bbox) + continue + + shrinked_bboxes.append(shrinked_bbox) + except Exception as e: + shrinked_bboxes.append(bbox) + + return shrinked_bboxes + + def __call__(self, data): + img = data['image'] + bboxes = data['polys'] + words = data['texts'] + scale_factor = data['scale_factor'] + + gt_instance = np.zeros(img.shape[0:2], dtype='uint8') # h,w + training_mask = np.ones(img.shape[0:2], dtype='uint8') + training_mask_distance = np.ones(img.shape[0:2], dtype='uint8') + + for i in range(len(bboxes)): + bboxes[i] = np.reshape(bboxes[i] * ( + [scale_factor[0], scale_factor[1]] * (bboxes[i].shape[0] // 2)), + (bboxes[i].shape[0] // 2, 2)).astype('int32') + + for i in range(len(bboxes)): + #different value for different bbox + cv2.drawContours(gt_instance, [bboxes[i]], -1, i + 1, -1) + + # set training mask to 0 + cv2.drawContours(training_mask, [bboxes[i]], -1, 0, -1) + + # for not accurate annotation, use training_mask_distance + if words[i] == '###' or words[i] == '???': + cv2.drawContours(training_mask_distance, [bboxes[i]], -1, 0, -1) + + # make shrink + 
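# build the per-instance kernel map: each annotated polygon is shrunk
+        # towards its centre by `kernel_scale` via a pyclipper offset (see
+        # self.shrink above), and instance i is drawn with value i + 1
+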
gt_kernel_instance = np.zeros(img.shape[0:2], dtype='uint8') + kernel_bboxes = self.shrink(bboxes, self.kernel_scale) + for i in range(len(bboxes)): + cv2.drawContours(gt_kernel_instance, [kernel_bboxes[i]], -1, i + 1, + -1) + + # for training mask, kernel and background= 1, box region=0 + if words[i] != '###' and words[i] != '???': + cv2.drawContours(training_mask, [kernel_bboxes[i]], -1, 1, -1) + + gt_kernel = gt_kernel_instance.copy() + # for gt_kernel, kernel = 1 + gt_kernel[gt_kernel > 0] = 1 + + # shrink 2 times + tmp1 = gt_kernel_instance.copy() + erode_kernel = np.ones((3, 3), np.uint8) + tmp1 = cv2.erode(tmp1, erode_kernel, iterations=1) + tmp2 = tmp1.copy() + tmp2 = cv2.erode(tmp2, erode_kernel, iterations=1) + + # compute text region + gt_kernel_inner = tmp1 - tmp2 + + # gt_instance: text instance, bg=0, diff word use diff value + # training_mask: text instance mask, word=0,kernel and bg=1 + # gt_kernel_instance: text kernel instance, bg=0, diff word use diff value + # gt_kernel: text_kernel, bg=0,diff word use same value + # gt_kernel_inner: text kernel reference + # training_mask_distance: word without anno = 0, else 1 + + data['image'] = [ + img, gt_instance, training_mask, gt_kernel_instance, gt_kernel, + gt_kernel_inner, training_mask_distance + ] + return data + + +class GroupRandomHorizontalFlip(): + def __init__(self, p=0.5, **kwargs): + self.p = p + + def __call__(self, data): + imgs = data['image'] + + if random.random() < self.p: + for i in range(len(imgs)): + imgs[i] = np.flip(imgs[i], axis=1).copy() + data['image'] = imgs + return data + + +class GroupRandomRotate(): + def __init__(self, **kwargs): + pass + + def __call__(self, data): + imgs = data['image'] + + max_angle = 10 + angle = random.random() * 2 * max_angle - max_angle + for i in range(len(imgs)): + img = imgs[i] + w, h = img.shape[:2] + rotation_matrix = cv2.getRotationMatrix2D((h / 2, w / 2), angle, 1) + img_rotation = cv2.warpAffine( + img, rotation_matrix, (h, w), flags=cv2.INTER_NEAREST) + imgs[i] = img_rotation + + data['image'] = imgs + return data + + +class GroupRandomCropPadding(): + def __init__(self, target_size=(640, 640), **kwargs): + self.target_size = target_size + + def __call__(self, data): + imgs = data['image'] + + h, w = imgs[0].shape[0:2] + t_w, t_h = self.target_size + p_w, p_h = self.target_size + if w == t_w and h == t_h: + return data + + t_h = t_h if t_h < h else h + t_w = t_w if t_w < w else w + + if random.random() > 3.0 / 8.0 and np.max(imgs[1]) > 0: + # make sure to crop the text region + tl = np.min(np.where(imgs[1] > 0), axis=1) - (t_h, t_w) + tl[tl < 0] = 0 + br = np.max(np.where(imgs[1] > 0), axis=1) - (t_h, t_w) + br[br < 0] = 0 + br[0] = min(br[0], h - t_h) + br[1] = min(br[1], w - t_w) + + i = random.randint(tl[0], br[0]) if tl[0] < br[0] else 0 + j = random.randint(tl[1], br[1]) if tl[1] < br[1] else 0 + else: + i = random.randint(0, h - t_h) if h - t_h > 0 else 0 + j = random.randint(0, w - t_w) if w - t_w > 0 else 0 + + n_imgs = [] + for idx in range(len(imgs)): + if len(imgs[idx].shape) == 3: + s3_length = int(imgs[idx].shape[-1]) + img = imgs[idx][i:i + t_h, j:j + t_w, :] + img_p = cv2.copyMakeBorder( + img, + 0, + p_h - t_h, + 0, + p_w - t_w, + borderType=cv2.BORDER_CONSTANT, + value=tuple(0 for i in range(s3_length))) + else: + img = imgs[idx][i:i + t_h, j:j + t_w] + img_p = cv2.copyMakeBorder( + img, + 0, + p_h - t_h, + 0, + p_w - t_w, + borderType=cv2.BORDER_CONSTANT, + value=(0, )) + n_imgs.append(img_p) + + data['image'] = n_imgs + return data + + +class 
MakeCentripetalShift(): + def __init__(self, **kwargs): + pass + + def jaccard(self, As, Bs): + A = As.shape[0] # small + B = Bs.shape[0] # large + + dis = np.sqrt( + np.sum((As[:, np.newaxis, :].repeat( + B, axis=1) - Bs[np.newaxis, :, :].repeat( + A, axis=0))**2, + axis=-1)) + + ind = np.argmin(dis, axis=-1) + + return ind + + def __call__(self, data): + imgs = data['image'] + + img, gt_instance, training_mask, gt_kernel_instance, gt_kernel, gt_kernel_inner, training_mask_distance = \ + imgs[0], imgs[1], imgs[2], imgs[3], imgs[4], imgs[5], imgs[6] + + max_instance = np.max(gt_instance) # num bbox + + # make centripetal shift + gt_distance = np.zeros((2, *img.shape[0:2]), dtype=np.float32) + for i in range(1, max_instance + 1): + # kernel_reference + ind = (gt_kernel_inner == i) + + if np.sum(ind) == 0: + training_mask[gt_instance == i] = 0 + training_mask_distance[gt_instance == i] = 0 + continue + + kpoints = np.array(np.where(ind)).transpose( + (1, 0))[:, ::-1].astype('float32') + + ind = (gt_instance == i) * (gt_kernel_instance == 0) + if np.sum(ind) == 0: + continue + pixels = np.where(ind) + + points = np.array(pixels).transpose( + (1, 0))[:, ::-1].astype('float32') + + bbox_ind = self.jaccard(points, kpoints) + + offset_gt = kpoints[bbox_ind] - points + + gt_distance[:, pixels[0], pixels[1]] = offset_gt.T * 0.1 + + img = Image.fromarray(img) + img = img.convert('RGB') + + data["image"] = img + data["gt_kernel"] = gt_kernel.astype("int64") + data["training_mask"] = training_mask.astype("int64") + data["gt_instance"] = gt_instance.astype("int64") + data["gt_kernel_instance"] = gt_kernel_instance.astype("int64") + data["training_mask_distance"] = training_mask_distance.astype("int64") + data["gt_distance"] = gt_distance.astype("float32") + + return data + + +class ScaleAlignedShort(): + def __init__(self, short_size=640, **kwargs): + self.short_size = short_size + + def __call__(self, data): + img = data['image'] + + org_img_shape = img.shape + + h, w = img.shape[0:2] + scale = self.short_size * 1.0 / min(h, w) + h = int(h * scale + 0.5) + w = int(w * scale + 0.5) + if h % 32 != 0: + h = h + (32 - h % 32) + if w % 32 != 0: + w = w + (32 - w % 32) + img = cv2.resize(img, dsize=(w, h)) + + new_img_shape = img.shape + img_shape = np.array(org_img_shape + new_img_shape) + + data['shape'] = img_shape + data['image'] = img + + return data \ No newline at end of file diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py index 59cb9b8a253cf04244ebf83511ab412174487a53..dbfb93176cc782bedc8f7b33367b59046c4abec8 100644 --- a/ppocr/data/imaug/label_ops.py +++ b/ppocr/data/imaug/label_ops.py @@ -1395,3 +1395,29 @@ class VLLabelEncode(BaseRecLabelEncode): data['label_res'] = np.array(label_res) data['label_sub'] = np.array(label_sub) return data + + +class CTLabelEncode(object): + def __init__(self, **kwargs): + pass + + def __call__(self, data): + label = data['label'] + + label = json.loads(label) + nBox = len(label) + boxes, txts = [], [] + for bno in range(0, nBox): + box = label[bno]['points'] + box = np.array(box) + + boxes.append(box) + txt = label[bno]['transcription'] + txts.append(txt) + + if len(boxes) == 0: + return None + + data['polys'] = boxes + data['texts'] = txts + return data \ No newline at end of file diff --git a/ppocr/data/imaug/operators.py b/ppocr/data/imaug/operators.py index f8ed28929707eb750ad6e8499a73568cae3a8e6b..5e84b1aac9c54d8a8283468af6826ca917ba0384 100644 --- a/ppocr/data/imaug/operators.py +++ b/ppocr/data/imaug/operators.py @@ -225,6 +225,8 
@@ class DetResizeForTest(object): def __call__(self, data): img = data['image'] src_h, src_w, _ = img.shape + if sum([src_h, src_w]) < 64: + img = self.image_padding(img) if self.resize_type == 0: # img, shape = self.resize_image_type0(img) @@ -238,6 +240,12 @@ class DetResizeForTest(object): data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w]) return data + def image_padding(self, im, value=0): + h, w, c = im.shape + im_pad = np.zeros((max(32, h), max(32, w), c), np.uint8) + value + im_pad[:h, :w, :] = im + return im_pad + def resize_image_type1(self, img): resize_h, resize_w = self.image_shape ori_h, ori_w = img.shape[:2] # (h, w, c) diff --git a/ppocr/data/imaug/pg_process.py b/ppocr/data/imaug/pg_process.py index 53031064c019ddce00c7546f898ac67a7f0459f9..f1e5f912b7a55dc3b9e883a9f4f8c5de482dcd5a 100644 --- a/ppocr/data/imaug/pg_process.py +++ b/ppocr/data/imaug/pg_process.py @@ -15,6 +15,8 @@ import math import cv2 import numpy as np +from skimage.morphology._skeletonize import thin +from ppocr.utils.e2e_utils.extract_textpoint_fast import sort_and_expand_with_direction_v2 __all__ = ['PGProcessTrain'] @@ -26,17 +28,24 @@ class PGProcessTrain(object): max_text_nums, tcl_len, batch_size=14, + use_resize=True, + use_random_crop=False, min_crop_size=24, min_text_size=4, max_text_size=512, + point_gather_mode=None, **kwargs): self.tcl_len = tcl_len self.max_text_length = max_text_length self.max_text_nums = max_text_nums self.batch_size = batch_size - self.min_crop_size = min_crop_size + if use_random_crop is True: + self.min_crop_size = min_crop_size + self.use_random_crop = use_random_crop self.min_text_size = min_text_size self.max_text_size = max_text_size + self.use_resize = use_resize + self.point_gather_mode = point_gather_mode self.Lexicon_Table = self.get_dict(character_dict_path) self.pad_num = len(self.Lexicon_Table) self.img_id = 0 @@ -282,6 +291,95 @@ class PGProcessTrain(object): pos_m[:keep] = 1.0 return pos_l, pos_m + def fit_and_gather_tcl_points_v3(self, + min_area_quad, + poly, + max_h, + max_w, + fixed_point_num=64, + img_id=0, + reference_height=3): + """ + Find the center point of poly as key_points, then fit and gather. 
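+        Compared with fit_and_gather_tcl_points_v2, this variant rasterises the
+        polygon, thins it to a skeleton, sorts and expands the points along the
+        predicted direction field, then samples a fixed number of them.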
+ """ + det_mask = np.zeros((int(max_h / self.ds_ratio), + int(max_w / self.ds_ratio))).astype(np.float32) + + # score_big_map + cv2.fillPoly(det_mask, + np.round(poly / self.ds_ratio).astype(np.int32), 1.0) + det_mask = cv2.resize( + det_mask, dsize=None, fx=self.ds_ratio, fy=self.ds_ratio) + det_mask = np.array(det_mask > 1e-3, dtype='float32') + + f_direction = self.f_direction + skeleton_map = thin(det_mask.astype(np.uint8)) + instance_count, instance_label_map = cv2.connectedComponents( + skeleton_map.astype(np.uint8), connectivity=8) + + ys, xs = np.where(instance_label_map == 1) + pos_list = list(zip(ys, xs)) + if len(pos_list) < 3: + return None + pos_list_sorted = sort_and_expand_with_direction_v2( + pos_list, f_direction, det_mask) + + pos_list_sorted = np.array(pos_list_sorted) + length = len(pos_list_sorted) - 1 + insert_num = 0 + for index in range(length): + stride_y = np.abs(pos_list_sorted[index + insert_num][0] - + pos_list_sorted[index + 1 + insert_num][0]) + stride_x = np.abs(pos_list_sorted[index + insert_num][1] - + pos_list_sorted[index + 1 + insert_num][1]) + max_points = int(max(stride_x, stride_y)) + + stride = (pos_list_sorted[index + insert_num] - + pos_list_sorted[index + 1 + insert_num]) / (max_points) + insert_num_temp = max_points - 1 + + for i in range(int(insert_num_temp)): + insert_value = pos_list_sorted[index + insert_num] - (i + 1 + ) * stride + insert_index = index + i + 1 + insert_num + pos_list_sorted = np.insert( + pos_list_sorted, insert_index, insert_value, axis=0) + insert_num += insert_num_temp + + pos_info = np.array(pos_list_sorted).reshape(-1, 2).astype( + np.float32) # xy-> yx + + point_num = len(pos_info) + if point_num > fixed_point_num: + keep_ids = [ + int((point_num * 1.0 / fixed_point_num) * x) + for x in range(fixed_point_num) + ] + pos_info = pos_info[keep_ids, :] + + keep = int(min(len(pos_info), fixed_point_num)) + reference_width = (np.abs(poly[0, 0, 0] - poly[-1, 1, 0]) + + np.abs(poly[0, 3, 0] - poly[-1, 2, 0])) // 2 + if np.random.rand() < 1: + dh = (np.random.rand(keep) - 0.5) * reference_height + offset = np.random.rand() - 0.5 + dw = np.array([[0, offset * reference_width * 0.2]]) + random_float_h = np.array([1, 0]).reshape([1, 2]) * dh.reshape( + [keep, 1]) + random_float_w = dw.repeat(keep, axis=0) + pos_info += random_float_h + pos_info += random_float_w + pos_info[:, 0] = np.clip(pos_info[:, 0], 0, max_h - 1) + pos_info[:, 1] = np.clip(pos_info[:, 1], 0, max_w - 1) + + # padding to fixed length + pos_l = np.zeros((self.tcl_len, 3), dtype=np.int32) + pos_l[:, 0] = np.ones((self.tcl_len, )) * img_id + pos_m = np.zeros((self.tcl_len, 1), dtype=np.float32) + pos_l[:keep, 1:] = np.round(pos_info).astype(np.int32) + pos_m[:keep] = 1.0 + return pos_l, pos_m + def generate_direction_map(self, poly_quads, n_char, direction_map): """ """ @@ -334,6 +432,7 @@ class PGProcessTrain(object): """ Generate polygon. 
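+        (The `self.ds_ratio` assignment added below caches the downsample ratio
+        so that fit_and_gather_tcl_points_v3 can rebuild full-resolution masks.)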
""" + self.ds_ratio = ds_ratio score_map_big = np.zeros( ( h, @@ -384,7 +483,6 @@ class PGProcessTrain(object): text_label = text_strs[poly_idx] text_label = self.prepare_text_label(text_label, self.Lexicon_Table) - text_label_index_list = [[self.Lexicon_Table.index(c_)] for c_ in text_label if c_ in self.Lexicon_Table] @@ -432,14 +530,30 @@ class PGProcessTrain(object): # pos info average_shrink_height = self.calculate_average_height( stcl_quads) - pos_l, pos_m = self.fit_and_gather_tcl_points_v2( - min_area_quad, - poly, - max_h=h, - max_w=w, - fixed_point_num=64, - img_id=self.img_id, - reference_height=average_shrink_height) + + if self.point_gather_mode == 'align': + self.f_direction = direction_map[:, :, :-1].copy() + pos_res = self.fit_and_gather_tcl_points_v3( + min_area_quad, + stcl_quads, + max_h=h, + max_w=w, + fixed_point_num=64, + img_id=self.img_id, + reference_height=average_shrink_height) + if pos_res is None: + continue + pos_l, pos_m = pos_res[0], pos_res[1] + + else: + pos_l, pos_m = self.fit_and_gather_tcl_points_v2( + min_area_quad, + poly, + max_h=h, + max_w=w, + fixed_point_num=64, + img_id=self.img_id, + reference_height=average_shrink_height) label_l = text_label_index_list if len(text_label_index_list) < 2: @@ -770,27 +884,41 @@ class PGProcessTrain(object): text_polys[:, :, 0] *= asp_wx text_polys[:, :, 1] *= asp_hy - h, w, _ = im.shape - if max(h, w) > 2048: - rd_scale = 2048.0 / max(h, w) - im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale) - text_polys *= rd_scale - h, w, _ = im.shape - if min(h, w) < 16: - return None - - # no background - im, text_polys, text_tags, hv_tags, text_strs = self.crop_area( - im, - text_polys, - text_tags, - hv_tags, - text_strs, - crop_background=False) + if self.use_resize is True: + ori_h, ori_w, _ = im.shape + if max(ori_h, ori_w) < 200: + ratio = 200 / max(ori_h, ori_w) + im = cv2.resize(im, (int(ori_w * ratio), int(ori_h * ratio))) + text_polys[:, :, 0] *= ratio + text_polys[:, :, 1] *= ratio + + if max(ori_h, ori_w) > 512: + ratio = 512 / max(ori_h, ori_w) + im = cv2.resize(im, (int(ori_w * ratio), int(ori_h * ratio))) + text_polys[:, :, 0] *= ratio + text_polys[:, :, 1] *= ratio + elif self.use_random_crop is True: + h, w, _ = im.shape + if max(h, w) > 2048: + rd_scale = 2048.0 / max(h, w) + im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale) + text_polys *= rd_scale + h, w, _ = im.shape + if min(h, w) < 16: + return None + + # no background + im, text_polys, text_tags, hv_tags, text_strs = self.crop_area( + im, + text_polys, + text_tags, + hv_tags, + text_strs, + crop_background=False) if text_polys.shape[0] == 0: return None - # # continue for all ignore case + # continue for all ignore case if np.sum((text_tags * 1.0)) >= text_tags.size: return None new_h, new_w, _ = im.shape diff --git a/ppocr/data/imaug/vqa/__init__.py b/ppocr/data/imaug/vqa/__init__.py index 34189bcefb17a0776bd62a19c58081286882b5a5..73f7dcdf712f6db0ff4354b1b01134d1277ff078 100644 --- a/ppocr/data/imaug/vqa/__init__.py +++ b/ppocr/data/imaug/vqa/__init__.py @@ -12,11 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from .token import VQATokenPad, VQASerTokenChunk, VQAReTokenChunk, VQAReTokenRelation +from .token import VQATokenPad, VQASerTokenChunk, VQAReTokenChunk, VQAReTokenRelation, TensorizeEntitiesRelations __all__ = [ - 'VQATokenPad', - 'VQASerTokenChunk', - 'VQAReTokenChunk', - 'VQAReTokenRelation', + 'VQATokenPad', 'VQASerTokenChunk', 'VQAReTokenChunk', 'VQAReTokenRelation', + 'TensorizeEntitiesRelations' ] diff --git a/ppocr/data/imaug/vqa/token/__init__.py b/ppocr/data/imaug/vqa/token/__init__.py index 7c115661753cd031b16ec34697157e2fcdcf2dec..5fbaa43db9e7182cfa3efa4f3dc0d9e54c17c822 100644 --- a/ppocr/data/imaug/vqa/token/__init__.py +++ b/ppocr/data/imaug/vqa/token/__init__.py @@ -15,3 +15,4 @@ from .vqa_token_chunk import VQASerTokenChunk, VQAReTokenChunk from .vqa_token_pad import VQATokenPad from .vqa_token_relation import VQAReTokenRelation +from .vqa_re_convert import TensorizeEntitiesRelations \ No newline at end of file diff --git a/ppocr/data/imaug/vqa/token/vqa_re_convert.py b/ppocr/data/imaug/vqa/token/vqa_re_convert.py new file mode 100644 index 0000000000000000000000000000000000000000..86962f2590b57f38640d76ef5d8b74ead5e854e0 --- /dev/null +++ b/ppocr/data/imaug/vqa/token/vqa_re_convert.py @@ -0,0 +1,51 @@ +# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
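+
+# TensorizeEntitiesRelations packs the variable-length `entities`/`relations`
+# dicts into fixed-size int64 arrays padded with -1; the first row stores the
+# true lengths so that RE samples can be collated into dense batch tensors.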
+ +import numpy as np + + +class TensorizeEntitiesRelations(object): + def __init__(self, max_seq_len=512, infer_mode=False, **kwargs): + self.max_seq_len = max_seq_len + self.infer_mode = infer_mode + + def __call__(self, data): + entities = data['entities'] + relations = data['relations'] + + entities_new = np.full( + shape=[self.max_seq_len + 1, 3], fill_value=-1, dtype='int64') + entities_new[0, 0] = len(entities['start']) + entities_new[0, 1] = len(entities['end']) + entities_new[0, 2] = len(entities['label']) + entities_new[1:len(entities['start']) + 1, 0] = np.array(entities[ + 'start']) + entities_new[1:len(entities['end']) + 1, 1] = np.array(entities['end']) + entities_new[1:len(entities['label']) + 1, 2] = np.array(entities[ + 'label']) + + relations_new = np.full( + shape=[self.max_seq_len * self.max_seq_len + 1, 2], + fill_value=-1, + dtype='int64') + relations_new[0, 0] = len(relations['head']) + relations_new[0, 1] = len(relations['tail']) + relations_new[1:len(relations['head']) + 1, 0] = np.array(relations[ + 'head']) + relations_new[1:len(relations['tail']) + 1, 1] = np.array(relations[ + 'tail']) + + data['entities'] = entities_new + data['relations'] = relations_new + return data diff --git a/ppocr/losses/__init__.py b/ppocr/losses/__init__.py index 1a11778945c9d7b5f5519cd55473e8bf7790db2c..02525b3d50ad87509a6cba6fb2c1b00cb0add56e 100755 --- a/ppocr/losses/__init__.py +++ b/ppocr/losses/__init__.py @@ -25,6 +25,7 @@ from .det_east_loss import EASTLoss from .det_sast_loss import SASTLoss from .det_pse_loss import PSELoss from .det_fce_loss import FCELoss +from .det_ct_loss import CTLoss # rec loss from .rec_ctc_loss import CTCLoss @@ -68,7 +69,7 @@ def build_loss(config): 'CELoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss', 'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss', 'TableMasterLoss', 'SPINAttentionLoss', 'VLLoss', 'StrokeFocusLoss', - 'SLALoss' + 'SLALoss', 'CTLoss' ] config = copy.deepcopy(config) module_name = config.pop('name') diff --git a/ppocr/losses/det_ct_loss.py b/ppocr/losses/det_ct_loss.py new file mode 100755 index 0000000000000000000000000000000000000000..f48c95be4f84e2d8520363379b3061fa4245c105 --- /dev/null +++ b/ppocr/losses/det_ct_loss.py @@ -0,0 +1,276 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +This code is refer from: +https://github.com/shengtao96/CentripetalText/tree/main/models/loss +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +from paddle import nn +import paddle.nn.functional as F +import numpy as np + + +def ohem_single(score, gt_text, training_mask): + # online hard example mining + + pos_num = int(paddle.sum(gt_text > 0.5)) - int( + paddle.sum((gt_text > 0.5) & (training_mask <= 0.5))) + + if pos_num == 0: + # selected_mask = gt_text.copy() * 0 # may be not good + selected_mask = training_mask + selected_mask = paddle.cast( + selected_mask.reshape( + (1, selected_mask.shape[0], selected_mask.shape[1])), "float32") + return selected_mask + + neg_num = int(paddle.sum((gt_text <= 0.5) & (training_mask > 0.5))) + neg_num = int(min(pos_num * 3, neg_num)) + + if neg_num == 0: + selected_mask = training_mask + selected_mask = paddle.cast( + selected_mask.reshape( + (1, selected_mask.shape[0], selected_mask.shape[1])), "float32") + return selected_mask + + # hard example + neg_score = score[(gt_text <= 0.5) & (training_mask > 0.5)] + neg_score_sorted = paddle.sort(-neg_score) + threshold = -neg_score_sorted[neg_num - 1] + + selected_mask = ((score >= threshold) | + (gt_text > 0.5)) & (training_mask > 0.5) + selected_mask = paddle.cast( + selected_mask.reshape( + (1, selected_mask.shape[0], selected_mask.shape[1])), "float32") + return selected_mask + + +def ohem_batch(scores, gt_texts, training_masks): + selected_masks = [] + for i in range(scores.shape[0]): + selected_masks.append( + ohem_single(scores[i, :, :], gt_texts[i, :, :], training_masks[ + i, :, :])) + + selected_masks = paddle.cast(paddle.concat(selected_masks, 0), "float32") + return selected_masks + + +def iou_single(a, b, mask, n_class): + EPS = 1e-6 + valid = mask == 1 + a = a[valid] + b = b[valid] + miou = [] + + # iou of each class + for i in range(n_class): + inter = paddle.cast(((a == i) & (b == i)), "float32") + union = paddle.cast(((a == i) | (b == i)), "float32") + + miou.append(paddle.sum(inter) / (paddle.sum(union) + EPS)) + miou = sum(miou) / len(miou) + return miou + + +def iou(a, b, mask, n_class=2, reduce=True): + batch_size = a.shape[0] + + a = a.reshape((batch_size, -1)) + b = b.reshape((batch_size, -1)) + mask = mask.reshape((batch_size, -1)) + + iou = paddle.zeros((batch_size, ), dtype="float32") + for i in range(batch_size): + iou[i] = iou_single(a[i], b[i], mask[i], n_class) + + if reduce: + iou = paddle.mean(iou) + return iou + + +class DiceLoss(nn.Layer): + def __init__(self, loss_weight=1.0): + super(DiceLoss, self).__init__() + self.loss_weight = loss_weight + + def forward(self, input, target, mask, reduce=True): + batch_size = input.shape[0] + input = F.sigmoid(input) # scale to 0-1 + + input = input.reshape((batch_size, -1)) + target = paddle.cast(target.reshape((batch_size, -1)), "float32") + mask = paddle.cast(mask.reshape((batch_size, -1)), "float32") + + input = input * mask + target = target * mask + + a = paddle.sum(input * target, axis=1) + b = paddle.sum(input * input, axis=1) + 0.001 + c = paddle.sum(target * target, axis=1) + 0.001 + d = (2 * a) / (b + c) + loss = 1 - d + + loss = self.loss_weight * loss + + if reduce: + loss = paddle.mean(loss) + + return loss + + +class SmoothL1Loss(nn.Layer): + def __init__(self, beta=1.0, loss_weight=1.0): + super(SmoothL1Loss, self).__init__() + self.beta = beta + self.loss_weight = loss_weight + + np_coord = np.zeros(shape=[640, 640, 2], 
dtype=np.int64) + for i in range(640): + for j in range(640): + np_coord[i, j, 0] = j + np_coord[i, j, 1] = i + np_coord = np_coord.reshape((-1, 2)) + + self.coord = self.create_parameter( + shape=[640 * 640, 2], + dtype="int32", # NOTE: not support "int64" before paddle 2.3.1 + default_initializer=nn.initializer.Assign(value=np_coord)) + self.coord.stop_gradient = True + + def forward_single(self, input, target, mask, beta=1.0, eps=1e-6): + batch_size = input.shape[0] + + diff = paddle.abs(input - target) * mask.unsqueeze(1) + loss = paddle.where(diff < beta, 0.5 * diff * diff / beta, + diff - 0.5 * beta) + loss = paddle.cast(loss.reshape((batch_size, -1)), "float32") + mask = paddle.cast(mask.reshape((batch_size, -1)), "float32") + loss = paddle.sum(loss, axis=-1) + loss = loss / (mask.sum(axis=-1) + eps) + + return loss + + def select_single(self, distance, gt_instance, gt_kernel_instance, + training_mask): + + with paddle.no_grad(): + # paddle 2.3.1, paddle.slice not support: + # distance[:, self.coord[:, 1], self.coord[:, 0]] + select_distance_list = [] + for i in range(2): + tmp1 = distance[i, :] + tmp2 = tmp1[self.coord[:, 1], self.coord[:, 0]] + select_distance_list.append(tmp2.unsqueeze(0)) + select_distance = paddle.concat(select_distance_list, axis=0) + + off_points = paddle.cast( + self.coord, "float32") + 10 * select_distance.transpose((1, 0)) + + off_points = paddle.cast(off_points, "int64") + off_points = paddle.clip(off_points, 0, distance.shape[-1] - 1) + + selected_mask = ( + gt_instance[self.coord[:, 1], self.coord[:, 0]] != + gt_kernel_instance[off_points[:, 1], off_points[:, 0]]) + selected_mask = paddle.cast( + selected_mask.reshape((1, -1, distance.shape[-1])), "int64") + selected_training_mask = selected_mask * training_mask + + return selected_training_mask + + def forward(self, + distances, + gt_instances, + gt_kernel_instances, + training_masks, + gt_distances, + reduce=True): + + selected_training_masks = [] + for i in range(distances.shape[0]): + selected_training_masks.append( + self.select_single(distances[i, :, :, :], gt_instances[i, :, :], + gt_kernel_instances[i, :, :], training_masks[ + i, :, :])) + selected_training_masks = paddle.cast( + paddle.concat(selected_training_masks, 0), "float32") + + loss = self.forward_single(distances, gt_distances, + selected_training_masks, self.beta) + loss = self.loss_weight * loss + + with paddle.no_grad(): + batch_size = distances.shape[0] + false_num = selected_training_masks.reshape((batch_size, -1)) + false_num = false_num.sum(axis=-1) + total_num = paddle.cast( + training_masks.reshape((batch_size, -1)), "float32") + total_num = total_num.sum(axis=-1) + iou_text = (total_num - false_num) / (total_num + 1e-6) + + if reduce: + loss = paddle.mean(loss) + + return loss, iou_text + + +class CTLoss(nn.Layer): + def __init__(self): + super(CTLoss, self).__init__() + self.kernel_loss = DiceLoss() + self.loc_loss = SmoothL1Loss(beta=0.1, loss_weight=0.05) + + def forward(self, preds, batch): + imgs = batch[0] + out = preds['maps'] + gt_kernels, training_masks, gt_instances, gt_kernel_instances, training_mask_distances, gt_distances = batch[ + 1:] + + kernels = out[:, 0, :, :] + distances = out[:, 1:, :, :] + + # kernel loss + selected_masks = ohem_batch(kernels, gt_kernels, training_masks) + + loss_kernel = self.kernel_loss( + kernels, gt_kernels, selected_masks, reduce=False) + + iou_kernel = iou(paddle.cast((kernels > 0), "int64"), + gt_kernels, + training_masks, + reduce=False) + losses = dict(loss_kernels=loss_kernel, 
) + + # loc loss + loss_loc, iou_text = self.loc_loss( + distances, + gt_instances, + gt_kernel_instances, + training_mask_distances, + gt_distances, + reduce=False) + losses.update(dict(loss_loc=loss_loc, )) + + loss_all = loss_kernel + loss_loc + losses = {'loss': loss_all} + + return losses diff --git a/ppocr/losses/distillation_loss.py b/ppocr/losses/distillation_loss.py index 87fed6235d73aef2695cd6db95662e615d52c94c..4bfbed75a338e2bd3bca0b80d16028030bf2f0b5 100644 --- a/ppocr/losses/distillation_loss.py +++ b/ppocr/losses/distillation_loss.py @@ -417,11 +417,13 @@ class DistillationVQADistanceLoss(DistanceLoss): mode="l2", model_name_pairs=[], key=None, + index=None, name="loss_distance", **kargs): super().__init__(mode=mode, **kargs) assert isinstance(model_name_pairs, list) self.key = key + self.index = index self.model_name_pairs = model_name_pairs self.name = name + "_l2" @@ -434,6 +436,9 @@ class DistillationVQADistanceLoss(DistanceLoss): if self.key is not None: out1 = out1[self.key] out2 = out2[self.key] + if self.index is not None: + out1 = out1[:, self.index, :, :] + out2 = out2[:, self.index, :, :] if attention_mask is not None: max_len = attention_mask.shape[-1] out1 = out1[:, :max_len] diff --git a/ppocr/losses/e2e_pg_loss.py b/ppocr/losses/e2e_pg_loss.py index 10a8ed0aa907123b155976ba498426604f23c2b0..aff67b7ce3c208bf9c7b1371e095eac8c70ce9df 100644 --- a/ppocr/losses/e2e_pg_loss.py +++ b/ppocr/losses/e2e_pg_loss.py @@ -89,12 +89,13 @@ class PGLoss(nn.Layer): tcl_pos = paddle.reshape(tcl_pos, [-1, 3]) tcl_pos = paddle.cast(tcl_pos, dtype=int) f_tcl_char = paddle.gather_nd(f_char, tcl_pos) - f_tcl_char = paddle.reshape(f_tcl_char, - [-1, 64, 37]) # len(Lexicon_Table)+1 - f_tcl_char_fg, f_tcl_char_bg = paddle.split(f_tcl_char, [36, 1], axis=2) + f_tcl_char = paddle.reshape( + f_tcl_char, [-1, 64, self.pad_num + 1]) # len(Lexicon_Table)+1 + f_tcl_char_fg, f_tcl_char_bg = paddle.split( + f_tcl_char, [self.pad_num, 1], axis=2) f_tcl_char_bg = f_tcl_char_bg * tcl_mask + (1.0 - tcl_mask) * 20.0 b, c, l = tcl_mask.shape - tcl_mask_fg = paddle.expand(x=tcl_mask, shape=[b, c, 36 * l]) + tcl_mask_fg = paddle.expand(x=tcl_mask, shape=[b, c, self.pad_num * l]) tcl_mask_fg.stop_gradient = True f_tcl_char_fg = f_tcl_char_fg * tcl_mask_fg + (1.0 - tcl_mask_fg) * ( -20.0) diff --git a/ppocr/metrics/__init__.py b/ppocr/metrics/__init__.py index 853647c06cf0519a0e049e14c16a0d3e26f9845b..a39d0a464f3f96b44d23cec55768223ca41311fa 100644 --- a/ppocr/metrics/__init__.py +++ b/ppocr/metrics/__init__.py @@ -31,12 +31,14 @@ from .kie_metric import KIEMetric from .vqa_token_ser_metric import VQASerTokenMetric from .vqa_token_re_metric import VQAReTokenMetric from .sr_metric import SRMetric +from .ct_metric import CTMetric + def build_metric(config): support_dict = [ "DetMetric", "DetFCEMetric", "RecMetric", "ClsMetric", "E2EMetric", "DistillationMetric", "TableMetric", 'KIEMetric', 'VQASerTokenMetric', - 'VQAReTokenMetric', 'SRMetric' + 'VQAReTokenMetric', 'SRMetric', 'CTMetric' ] config = copy.deepcopy(config) diff --git a/ppocr/metrics/ct_metric.py b/ppocr/metrics/ct_metric.py new file mode 100644 index 0000000000000000000000000000000000000000..a7634230a23027a5dd5c32a7b8eb87ee4a229076 --- /dev/null +++ b/ppocr/metrics/ct_metric.py @@ -0,0 +1,52 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +from scipy import io +import numpy as np + +from ppocr.utils.e2e_metric.Deteval import combine_results, get_score_C + + +class CTMetric(object): + def __init__(self, main_indicator, delimiter='\t', **kwargs): + self.delimiter = delimiter + self.main_indicator = main_indicator + self.reset() + + def reset(self): + self.results = [] # clear results + + def __call__(self, preds, batch, **kwargs): + # NOTE: only support bs=1 now, as the label length of different sample is Unequal + assert len( + preds) == 1, "CentripetalText test now only suuport batch_size=1." + label = batch[2] + text = batch[3] + pred = preds[0]['points'] + result = get_score_C(label, text, pred) + + self.results.append(result) + + def get_metric(self): + """ + Input format: y0,x0, ..... yn,xn. Each detection is separated by the end of line token ('\n')' + """ + metrics = combine_results(self.results, rec_flag=False) + self.reset() + return metrics diff --git a/ppocr/metrics/vqa_token_re_metric.py b/ppocr/metrics/vqa_token_re_metric.py index f84387d8beb729bcc4b420ceea24a5e9b2993c64..0509984f7e7e85fc1dae859761fedb7356a02477 100644 --- a/ppocr/metrics/vqa_token_re_metric.py +++ b/ppocr/metrics/vqa_token_re_metric.py @@ -37,23 +37,25 @@ class VQAReTokenMetric(object): gt_relations = [] for b in range(len(self.relations_list)): rel_sent = [] - if "head" in self.relations_list[b]: - for head, tail in zip(self.relations_list[b]["head"], - self.relations_list[b]["tail"]): + relation_list = self.relations_list[b] + entitie_list = self.entities_list[b] + head_len = relation_list[0, 0] + if head_len > 0: + entitie_start_list = entitie_list[1:entitie_list[0, 0] + 1, 0] + entitie_end_list = entitie_list[1:entitie_list[0, 1] + 1, 1] + entitie_label_list = entitie_list[1:entitie_list[0, 2] + 1, 2] + for head, tail in zip(relation_list[1:head_len + 1, 0], + relation_list[1:head_len + 1, 1]): rel = {} rel["head_id"] = head - rel["head"] = ( - self.entities_list[b]["start"][rel["head_id"]], - self.entities_list[b]["end"][rel["head_id"]]) - rel["head_type"] = self.entities_list[b]["label"][rel[ - "head_id"]] + rel["head"] = (entitie_start_list[head], + entitie_end_list[head]) + rel["head_type"] = entitie_label_list[head] rel["tail_id"] = tail - rel["tail"] = ( - self.entities_list[b]["start"][rel["tail_id"]], - self.entities_list[b]["end"][rel["tail_id"]]) - rel["tail_type"] = self.entities_list[b]["label"][rel[ - "tail_id"]] + rel["tail"] = (entitie_start_list[tail], + entitie_end_list[tail]) + rel["tail_type"] = entitie_label_list[tail] rel["type"] = 1 rel_sent.append(rel) diff --git a/ppocr/modeling/backbones/vqa_layoutlm.py b/ppocr/modeling/backbones/vqa_layoutlm.py index 8e10ed7b48e9aff344b71e5a04970d1a5dab8a71..acb1315cc0a588396549e5b8928bd2e4d3c769be 100644 --- a/ppocr/modeling/backbones/vqa_layoutlm.py +++ b/ppocr/modeling/backbones/vqa_layoutlm.py @@ -218,8 +218,12 @@ class LayoutXLMForRe(NLPBaseModel): def forward(self, x): if self.use_visual_backbone is True: image = x[4] + entities = x[5] + 
relations = x[6] else: image = None + entities = x[4] + relations = x[5] x = self.model( input_ids=x[0], bbox=x[1], @@ -229,6 +233,6 @@ class LayoutXLMForRe(NLPBaseModel): position_ids=None, head_mask=None, labels=None, - entities=x[5], - relations=x[6]) + entities=entities, + relations=relations) return x diff --git a/ppocr/modeling/heads/__init__.py b/ppocr/modeling/heads/__init__.py index 0feda6c6e062fa314d97b8949d8545ed3305c22e..751757e5f176119688e2db47a68c514850b91823 100755 --- a/ppocr/modeling/heads/__init__.py +++ b/ppocr/modeling/heads/__init__.py @@ -23,6 +23,7 @@ def build_head(config): from .det_pse_head import PSEHead from .det_fce_head import FCEHead from .e2e_pg_head import PGHead + from .det_ct_head import CT_Head # rec head from .rec_ctc_head import CTCHead @@ -52,7 +53,7 @@ def build_head(config): 'ClsHead', 'AttentionHead', 'SRNHead', 'PGHead', 'Transformer', 'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead', 'MultiHead', 'ABINetHead', 'TableMasterHead', 'SPINAttentionHead', - 'VLHead', 'SLAHead', 'RobustScannerHead' + 'VLHead', 'SLAHead', 'RobustScannerHead', 'CT_Head' ] #table head diff --git a/ppocr/modeling/heads/det_ct_head.py b/ppocr/modeling/heads/det_ct_head.py new file mode 100644 index 0000000000000000000000000000000000000000..08e6719e8f0ade6887eb4ad7f44a2bc36ec132db --- /dev/null +++ b/ppocr/modeling/heads/det_ct_head.py @@ -0,0 +1,69 @@ +# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math +import paddle +from paddle import nn +import paddle.nn.functional as F +from paddle import ParamAttr + +import math +from paddle.nn.initializer import TruncatedNormal, Constant, Normal +ones_ = Constant(value=1.) +zeros_ = Constant(value=0.) + + +class CT_Head(nn.Layer): + def __init__(self, + in_channels, + hidden_dim, + num_classes, + loss_kernel=None, + loss_loc=None): + super(CT_Head, self).__init__() + self.conv1 = nn.Conv2D( + in_channels, hidden_dim, kernel_size=3, stride=1, padding=1) + self.bn1 = nn.BatchNorm2D(hidden_dim) + self.relu1 = nn.ReLU() + + self.conv2 = nn.Conv2D( + hidden_dim, num_classes, kernel_size=1, stride=1, padding=0) + + for m in self.sublayers(): + if isinstance(m, nn.Conv2D): + n = m._kernel_size[0] * m._kernel_size[1] * m._out_channels + normal_ = Normal(mean=0.0, std=math.sqrt(2. 
/ n)) + normal_(m.weight) + elif isinstance(m, nn.BatchNorm2D): + zeros_(m.bias) + ones_(m.weight) + + def _upsample(self, x, scale=1): + return F.upsample(x, scale_factor=scale, mode='bilinear') + + def forward(self, f, targets=None): + out = self.conv1(f) + out = self.relu1(self.bn1(out)) + out = self.conv2(out) + + if self.training: + out = self._upsample(out, scale=4) + return {'maps': out} + else: + score = F.sigmoid(out[:, 0, :, :]) + return {'maps': out, 'score': score} diff --git a/ppocr/modeling/heads/e2e_pg_head.py b/ppocr/modeling/heads/e2e_pg_head.py index 274e1cdac5172f45590c9f7d7b50522c74db6750..514962ef97e503d331b6351c6d314070dfd8b15f 100644 --- a/ppocr/modeling/heads/e2e_pg_head.py +++ b/ppocr/modeling/heads/e2e_pg_head.py @@ -66,8 +66,17 @@ class PGHead(nn.Layer): """ """ - def __init__(self, in_channels, **kwargs): + def __init__(self, + in_channels, + character_dict_path='ppocr/utils/ic15_dict.txt', + **kwargs): super(PGHead, self).__init__() + + # get character_length + with open(character_dict_path, "rb") as fin: + lines = fin.readlines() + character_length = len(lines) + 1 + self.conv_f_score1 = ConvBNLayer( in_channels=in_channels, out_channels=64, @@ -178,7 +187,7 @@ class PGHead(nn.Layer): name="conv_f_char{}".format(5)) self.conv3 = nn.Conv2D( in_channels=256, - out_channels=37, + out_channels=character_length, kernel_size=3, stride=1, padding=1, diff --git a/ppocr/modeling/heads/table_att_head.py b/ppocr/modeling/heads/table_att_head.py index 00b434105bd9fe1f0d928c5f026dc5804b33fe23..50910c5b73aa2a41f329d7222fc8c632509b4c91 100644 --- a/ppocr/modeling/heads/table_att_head.py +++ b/ppocr/modeling/heads/table_att_head.py @@ -16,6 +16,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import math import paddle import paddle.nn as nn from paddle import ParamAttr @@ -42,7 +43,6 @@ class TableAttentionHead(nn.Layer): def __init__(self, in_channels, hidden_size, - loc_type, in_max_len=488, max_text_length=800, out_channels=30, @@ -57,20 +57,16 @@ class TableAttentionHead(nn.Layer): self.structure_attention_cell = AttentionGRUCell( self.input_size, hidden_size, self.out_channels, use_gru=False) self.structure_generator = nn.Linear(hidden_size, self.out_channels) - self.loc_type = loc_type self.in_max_len = in_max_len - if self.loc_type == 1: - self.loc_generator = nn.Linear(hidden_size, 4) + if self.in_max_len == 640: + self.loc_fea_trans = nn.Linear(400, self.max_text_length + 1) + elif self.in_max_len == 800: + self.loc_fea_trans = nn.Linear(625, self.max_text_length + 1) else: - if self.in_max_len == 640: - self.loc_fea_trans = nn.Linear(400, self.max_text_length + 1) - elif self.in_max_len == 800: - self.loc_fea_trans = nn.Linear(625, self.max_text_length + 1) - else: - self.loc_fea_trans = nn.Linear(256, self.max_text_length + 1) - self.loc_generator = nn.Linear(self.input_size + hidden_size, - loc_reg_num) + self.loc_fea_trans = nn.Linear(256, self.max_text_length + 1) + self.loc_generator = nn.Linear(self.input_size + hidden_size, + loc_reg_num) def _char_to_onehot(self, input_char, onehot_dim): input_ont_hot = F.one_hot(input_char, onehot_dim) @@ -80,16 +76,13 @@ class TableAttentionHead(nn.Layer): # if and else branch are both needed when you want to assign a variable # if you modify the var in just one branch, then the modification will not work. 
fea = inputs[-1] - if len(fea.shape) == 3: - pass - else: - last_shape = int(np.prod(fea.shape[2:])) # gry added - fea = paddle.reshape(fea, [fea.shape[0], fea.shape[1], last_shape]) - fea = fea.transpose([0, 2, 1]) # (NTC)(batch, width, channels) + last_shape = int(np.prod(fea.shape[2:])) # gry added + fea = paddle.reshape(fea, [fea.shape[0], fea.shape[1], last_shape]) + fea = fea.transpose([0, 2, 1]) # (NTC)(batch, width, channels) batch_size = fea.shape[0] hidden = paddle.zeros((batch_size, self.hidden_size)) - output_hiddens = [] + output_hiddens = paddle.zeros((batch_size, self.max_text_length + 1, self.hidden_size)) if self.training and targets is not None: structure = targets[0] for i in range(self.max_text_length + 1): @@ -97,7 +90,8 @@ class TableAttentionHead(nn.Layer): structure[:, i], onehot_dim=self.out_channels) (outputs, hidden), alpha = self.structure_attention_cell( hidden, fea, elem_onehots) - output_hiddens.append(paddle.unsqueeze(outputs, axis=1)) + output_hiddens[:, i, :] = outputs + # output_hiddens.append(paddle.unsqueeze(outputs, axis=1)) output = paddle.concat(output_hiddens, axis=1) structure_probs = self.structure_generator(output) if self.loc_type == 1: @@ -118,30 +112,25 @@ class TableAttentionHead(nn.Layer): outputs = None alpha = None max_text_length = paddle.to_tensor(self.max_text_length) - i = 0 - while i < max_text_length + 1: + for i in range(max_text_length + 1): elem_onehots = self._char_to_onehot( temp_elem, onehot_dim=self.out_channels) (outputs, hidden), alpha = self.structure_attention_cell( hidden, fea, elem_onehots) - output_hiddens.append(paddle.unsqueeze(outputs, axis=1)) + output_hiddens[:, i, :] = outputs + # output_hiddens.append(paddle.unsqueeze(outputs, axis=1)) structure_probs_step = self.structure_generator(outputs) temp_elem = structure_probs_step.argmax(axis=1, dtype="int32") - i += 1 - output = paddle.concat(output_hiddens, axis=1) + output = output_hiddens structure_probs = self.structure_generator(output) structure_probs = F.softmax(structure_probs) - if self.loc_type == 1: - loc_preds = self.loc_generator(output) - loc_preds = F.sigmoid(loc_preds) - else: - loc_fea = fea.transpose([0, 2, 1]) - loc_fea = self.loc_fea_trans(loc_fea) - loc_fea = loc_fea.transpose([0, 2, 1]) - loc_concat = paddle.concat([output, loc_fea], axis=2) - loc_preds = self.loc_generator(loc_concat) - loc_preds = F.sigmoid(loc_preds) + loc_fea = fea.transpose([0, 2, 1]) + loc_fea = self.loc_fea_trans(loc_fea) + loc_fea = loc_fea.transpose([0, 2, 1]) + loc_concat = paddle.concat([output, loc_fea], axis=2) + loc_preds = self.loc_generator(loc_concat) + loc_preds = F.sigmoid(loc_preds) return {'structure_probs': structure_probs, 'loc_preds': loc_preds} @@ -166,6 +155,7 @@ class SLAHead(nn.Layer): self.max_text_length = max_text_length self.emb = self._char_to_onehot self.num_embeddings = out_channels + self.loc_reg_num = loc_reg_num # structure self.structure_attention_cell = AttentionGRUCell( @@ -213,15 +203,17 @@ class SLAHead(nn.Layer): fea = fea.transpose([0, 2, 1]) # (NTC)(batch, width, channels) hidden = paddle.zeros((batch_size, self.hidden_size)) - structure_preds = [] - loc_preds = [] + structure_preds = paddle.zeros((batch_size, self.max_text_length + 1, self.num_embeddings)) + loc_preds = paddle.zeros((batch_size, self.max_text_length + 1, self.loc_reg_num)) + structure_preds.stop_gradient = True + loc_preds.stop_gradient = True if self.training and targets is not None: structure = targets[0] for i in range(self.max_text_length + 1): hidden, 
structure_step, loc_step = self._decode(structure[:, i], fea, hidden) - structure_preds.append(structure_step) - loc_preds.append(loc_step) + structure_preds[:, i, :] = structure_step + loc_preds[:, i, :] = loc_step else: pre_chars = paddle.zeros(shape=[batch_size], dtype="int32") max_text_length = paddle.to_tensor(self.max_text_length) @@ -231,10 +223,8 @@ class SLAHead(nn.Layer): hidden, structure_step, loc_step = self._decode(pre_chars, fea, hidden) pre_chars = structure_step.argmax(axis=1, dtype="int32") - structure_preds.append(structure_step) - loc_preds.append(loc_step) - structure_preds = paddle.stack(structure_preds, axis=1) - loc_preds = paddle.stack(loc_preds, axis=1) + structure_preds[:, i, :] = structure_step + loc_preds[:, i, :] = loc_step if not self.training: structure_preds = F.softmax(structure_preds) return {'structure_probs': structure_preds, 'loc_preds': loc_preds} diff --git a/ppocr/modeling/necks/__init__.py b/ppocr/modeling/necks/__init__.py index e3ae2d6ef27821f592645a4ba945d3feeaa8cf8a..c7e8dd068b4a68e56b066ca8fa629644a8f302c6 100644 --- a/ppocr/modeling/necks/__init__.py +++ b/ppocr/modeling/necks/__init__.py @@ -26,13 +26,15 @@ def build_neck(config): from .fce_fpn import FCEFPN from .pren_fpn import PRENFPN from .csp_pan import CSPPAN + from .ct_fpn import CTFPN support_dict = [ 'FPN', 'FCEFPN', 'LKPAN', 'DBFPN', 'RSEFPN', 'EASTFPN', 'SASTFPN', - 'SequenceEncoder', 'PGFPN', 'TableFPN', 'PRENFPN', 'CSPPAN' + 'SequenceEncoder', 'PGFPN', 'TableFPN', 'PRENFPN', 'CSPPAN', 'CTFPN' ] module_name = config.pop('name') assert module_name in support_dict, Exception('neck only support {}'.format( support_dict)) + module_class = eval(module_name)(**config) return module_class diff --git a/ppocr/modeling/necks/ct_fpn.py b/ppocr/modeling/necks/ct_fpn.py new file mode 100644 index 0000000000000000000000000000000000000000..ee4d25e901b5b3093588571f0412a931eaf6f364 --- /dev/null +++ b/ppocr/modeling/necks/ct_fpn.py @@ -0,0 +1,185 @@ +# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +from paddle import nn +import paddle.nn.functional as F +from paddle import ParamAttr +import os +import sys + +import math +from paddle.nn.initializer import TruncatedNormal, Constant, Normal +ones_ = Constant(value=1.) +zeros_ = Constant(value=0.) 
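+
+# CTFPN below implements the CentripetalText neck: the four backbone feature maps
+# are reduced to 128 channels, refined by two cascaded FPEM modules (depthwise 3x3
+# convolutions followed by 1x1 Conv-BN-ReLU smoothing, run in an up-down and then
+# a down-up pass), fused by element-wise addition (FFM), upsampled to the stride
+# of the largest map and concatenated into a single 512-channel feature map.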
+ +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(__dir__) +sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../../..'))) + + +class Conv_BN_ReLU(nn.Layer): + def __init__(self, + in_planes, + out_planes, + kernel_size=1, + stride=1, + padding=0): + super(Conv_BN_ReLU, self).__init__() + self.conv = nn.Conv2D( + in_planes, + out_planes, + kernel_size=kernel_size, + stride=stride, + padding=padding, + bias_attr=False) + self.bn = nn.BatchNorm2D(out_planes) + self.relu = nn.ReLU() + + for m in self.sublayers(): + if isinstance(m, nn.Conv2D): + n = m._kernel_size[0] * m._kernel_size[1] * m._out_channels + normal_ = Normal(mean=0.0, std=math.sqrt(2. / n)) + normal_(m.weight) + elif isinstance(m, nn.BatchNorm2D): + zeros_(m.bias) + ones_(m.weight) + + def forward(self, x): + return self.relu(self.bn(self.conv(x))) + + +class FPEM(nn.Layer): + def __init__(self, in_channels, out_channels): + super(FPEM, self).__init__() + planes = out_channels + self.dwconv3_1 = nn.Conv2D( + planes, + planes, + kernel_size=3, + stride=1, + padding=1, + groups=planes, + bias_attr=False) + self.smooth_layer3_1 = Conv_BN_ReLU(planes, planes) + + self.dwconv2_1 = nn.Conv2D( + planes, + planes, + kernel_size=3, + stride=1, + padding=1, + groups=planes, + bias_attr=False) + self.smooth_layer2_1 = Conv_BN_ReLU(planes, planes) + + self.dwconv1_1 = nn.Conv2D( + planes, + planes, + kernel_size=3, + stride=1, + padding=1, + groups=planes, + bias_attr=False) + self.smooth_layer1_1 = Conv_BN_ReLU(planes, planes) + + self.dwconv2_2 = nn.Conv2D( + planes, + planes, + kernel_size=3, + stride=2, + padding=1, + groups=planes, + bias_attr=False) + self.smooth_layer2_2 = Conv_BN_ReLU(planes, planes) + + self.dwconv3_2 = nn.Conv2D( + planes, + planes, + kernel_size=3, + stride=2, + padding=1, + groups=planes, + bias_attr=False) + self.smooth_layer3_2 = Conv_BN_ReLU(planes, planes) + + self.dwconv4_2 = nn.Conv2D( + planes, + planes, + kernel_size=3, + stride=2, + padding=1, + groups=planes, + bias_attr=False) + self.smooth_layer4_2 = Conv_BN_ReLU(planes, planes) + + def _upsample_add(self, x, y): + return F.upsample(x, scale_factor=2, mode='bilinear') + y + + def forward(self, f1, f2, f3, f4): + # up-down + f3 = self.smooth_layer3_1(self.dwconv3_1(self._upsample_add(f4, f3))) + f2 = self.smooth_layer2_1(self.dwconv2_1(self._upsample_add(f3, f2))) + f1 = self.smooth_layer1_1(self.dwconv1_1(self._upsample_add(f2, f1))) + + # down-up + f2 = self.smooth_layer2_2(self.dwconv2_2(self._upsample_add(f2, f1))) + f3 = self.smooth_layer3_2(self.dwconv3_2(self._upsample_add(f3, f2))) + f4 = self.smooth_layer4_2(self.dwconv4_2(self._upsample_add(f4, f3))) + + return f1, f2, f3, f4 + + +class CTFPN(nn.Layer): + def __init__(self, in_channels, out_channel=128): + super(CTFPN, self).__init__() + self.out_channels = out_channel * 4 + + self.reduce_layer1 = Conv_BN_ReLU(in_channels[0], 128) + self.reduce_layer2 = Conv_BN_ReLU(in_channels[1], 128) + self.reduce_layer3 = Conv_BN_ReLU(in_channels[2], 128) + self.reduce_layer4 = Conv_BN_ReLU(in_channels[3], 128) + + self.fpem1 = FPEM(in_channels=(64, 128, 256, 512), out_channels=128) + self.fpem2 = FPEM(in_channels=(64, 128, 256, 512), out_channels=128) + + def _upsample(self, x, scale=1): + return F.upsample(x, scale_factor=scale, mode='bilinear') + + def forward(self, f): + # # reduce channel + f1 = self.reduce_layer1(f[0]) # N,64,160,160 --> N, 128, 160, 160 + f2 = self.reduce_layer2(f[1]) # N, 128, 80, 80 --> N, 128, 80, 80 + f3 = self.reduce_layer3(f[2]) # N, 256, 40, 
40 --> N, 128, 40, 40 + f4 = self.reduce_layer4(f[3]) # N, 512, 20, 20 --> N, 128, 20, 20 + + # FPEM + f1_1, f2_1, f3_1, f4_1 = self.fpem1(f1, f2, f3, f4) + f1_2, f2_2, f3_2, f4_2 = self.fpem2(f1_1, f2_1, f3_1, f4_1) + + # FFM + f1 = f1_1 + f1_2 + f2 = f2_1 + f2_2 + f3 = f3_1 + f3_2 + f4 = f4_1 + f4_2 + + f2 = self._upsample(f2, scale=2) + f3 = self._upsample(f3, scale=4) + f4 = self._upsample(f4, scale=8) + ff = paddle.concat((f1, f2, f3, f4), 1) # N,512, 160,160 + return ff diff --git a/ppocr/postprocess/__init__.py b/ppocr/postprocess/__init__.py index 8f41a005f5b90e7edf11fad80b9b7eac89257160..35b7a6800da422264a796da14236ae8a484c30d9 100644 --- a/ppocr/postprocess/__init__.py +++ b/ppocr/postprocess/__init__.py @@ -35,6 +35,7 @@ from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess, from .vqa_token_re_layoutlm_postprocess import VQAReTokenLayoutLMPostProcess, DistillationRePostProcess from .table_postprocess import TableMasterLabelDecode, TableLabelDecode from .picodet_postprocess import PicoDetPostProcess +from .ct_postprocess import CTPostProcess def build_post_process(config, global_config=None): @@ -48,7 +49,7 @@ def build_post_process(config, global_config=None): 'DistillationSARLabelDecode', 'ViTSTRLabelDecode', 'ABINetLabelDecode', 'TableMasterLabelDecode', 'SPINLabelDecode', 'DistillationSerPostProcess', 'DistillationRePostProcess', - 'VLLabelDecode', 'PicoDetPostProcess' + 'VLLabelDecode', 'PicoDetPostProcess', 'CTPostProcess' ] if config['name'] == 'PSEPostProcess': diff --git a/ppocr/postprocess/ct_postprocess.py b/ppocr/postprocess/ct_postprocess.py new file mode 100755 index 0000000000000000000000000000000000000000..3ab90be24d65888339698a5abe2ed692ceaab4c7 --- /dev/null +++ b/ppocr/postprocess/ct_postprocess.py @@ -0,0 +1,154 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This code is refered from: +https://github.com/shengtao96/CentripetalText/blob/main/test.py +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import os.path as osp +import numpy as np +import cv2 +import paddle +import pyclipper + + +class CTPostProcess(object): + """ + The post process for Centripetal Text (CT). 
+ """ + + def __init__(self, min_score=0.88, min_area=16, box_type='poly', **kwargs): + self.min_score = min_score + self.min_area = min_area + self.box_type = box_type + + self.coord = np.zeros((2, 300, 300), dtype=np.int32) + for i in range(300): + for j in range(300): + self.coord[0, i, j] = j + self.coord[1, i, j] = i + + def __call__(self, preds, batch): + outs = preds['maps'] + out_scores = preds['score'] + + if isinstance(outs, paddle.Tensor): + outs = outs.numpy() + if isinstance(out_scores, paddle.Tensor): + out_scores = out_scores.numpy() + + batch_size = outs.shape[0] + boxes_batch = [] + for idx in range(batch_size): + bboxes = [] + scores = [] + + img_shape = batch[idx] + + org_img_size = img_shape[:3] + img_shape = img_shape[3:] + img_size = img_shape[:2] + + out = np.expand_dims(outs[idx], axis=0) + outputs = dict() + + score = np.expand_dims(out_scores[idx], axis=0) + + kernel = out[:, 0, :, :] > 0.2 + loc = out[:, 1:, :, :].astype("float32") + + score = score[0].astype(np.float32) + kernel = kernel[0].astype(np.uint8) + loc = loc[0].astype(np.float32) + + label_num, label_kernel = cv2.connectedComponents( + kernel, connectivity=4) + + for i in range(1, label_num): + ind = (label_kernel == i) + if ind.sum( + ) < 10: # pixel number less than 10, treated as background + label_kernel[ind] = 0 + + label = np.zeros_like(label_kernel) + h, w = label_kernel.shape + pixels = self.coord[:, :h, :w].reshape(2, -1) + points = pixels.transpose([1, 0]).astype(np.float32) + + off_points = (points + 10. / 4. * loc[:, pixels[1], pixels[0]].T + ).astype(np.int32) + off_points[:, 0] = np.clip(off_points[:, 0], 0, label.shape[1] - 1) + off_points[:, 1] = np.clip(off_points[:, 1], 0, label.shape[0] - 1) + + label[pixels[1], pixels[0]] = label_kernel[off_points[:, 1], + off_points[:, 0]] + label[label_kernel > 0] = label_kernel[label_kernel > 0] + + score_pocket = [0.0] + for i in range(1, label_num): + ind = (label_kernel == i) + if ind.sum() == 0: + score_pocket.append(0.0) + continue + score_i = np.mean(score[ind]) + score_pocket.append(score_i) + + label_num = np.max(label) + 1 + label = cv2.resize( + label, (img_size[1], img_size[0]), + interpolation=cv2.INTER_NEAREST) + + scale = (float(org_img_size[1]) / float(img_size[1]), + float(org_img_size[0]) / float(img_size[0])) + + for i in range(1, label_num): + ind = (label == i) + points = np.array(np.where(ind)).transpose((1, 0)) + + if points.shape[0] < self.min_area: + continue + + score_i = score_pocket[i] + if score_i < self.min_score: + continue + + if self.box_type == 'rect': + rect = cv2.minAreaRect(points[:, ::-1]) + bbox = cv2.boxPoints(rect) * scale + z = bbox.mean(0) + bbox = z + (bbox - z) * 0.85 + elif self.box_type == 'poly': + binary = np.zeros(label.shape, dtype='uint8') + binary[ind] = 1 + try: + _, contours, _ = cv2.findContours( + binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + except BaseException: + contours, _ = cv2.findContours( + binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + + bbox = contours[0] * scale + + bbox = bbox.astype('int32') + bboxes.append(bbox.reshape(-1, 2)) + scores.append(score_i) + + boxes_batch.append({'points': bboxes}) + + return boxes_batch diff --git a/ppocr/postprocess/pg_postprocess.py b/ppocr/postprocess/pg_postprocess.py index 0b1455181fddb0adb5347406bb2eb3093ee6fb30..058cf8b907de296094d3ed2fc7e6981939ced328 100644 --- a/ppocr/postprocess/pg_postprocess.py +++ b/ppocr/postprocess/pg_postprocess.py @@ -30,12 +30,18 @@ class PGPostProcess(object): The post process for PGNet. 
""" - def __init__(self, character_dict_path, valid_set, score_thresh, mode, + def __init__(self, + character_dict_path, + valid_set, + score_thresh, + mode, + point_gather_mode=None, **kwargs): self.character_dict_path = character_dict_path self.valid_set = valid_set self.score_thresh = score_thresh self.mode = mode + self.point_gather_mode = point_gather_mode # c++ la-nms is faster, but only support python 3.5 self.is_python35 = False @@ -43,8 +49,13 @@ class PGPostProcess(object): self.is_python35 = True def __call__(self, outs_dict, shape_list): - post = PGNet_PostProcess(self.character_dict_path, self.valid_set, - self.score_thresh, outs_dict, shape_list) + post = PGNet_PostProcess( + self.character_dict_path, + self.valid_set, + self.score_thresh, + outs_dict, + shape_list, + point_gather_mode=self.point_gather_mode) if self.mode == 'fast': data = post.pg_postprocess_fast() else: diff --git a/ppocr/postprocess/vqa_token_re_layoutlm_postprocess.py b/ppocr/postprocess/vqa_token_re_layoutlm_postprocess.py index 96c25d9aac01066f7a3841fe61aa7b0fe05041bd..64f7d761950249eaef2946e09365dbaab4d94c6c 100644 --- a/ppocr/postprocess/vqa_token_re_layoutlm_postprocess.py +++ b/ppocr/postprocess/vqa_token_re_layoutlm_postprocess.py @@ -21,18 +21,22 @@ class VQAReTokenLayoutLMPostProcess(object): super(VQAReTokenLayoutLMPostProcess, self).__init__() def __call__(self, preds, label=None, *args, **kwargs): + pred_relations = preds['pred_relations'] + if isinstance(preds['pred_relations'], paddle.Tensor): + pred_relations = pred_relations.numpy() + pred_relations = self.decode_pred(pred_relations) + if label is not None: - return self._metric(preds, label) + return self._metric(pred_relations, label) else: - return self._infer(preds, *args, **kwargs) + return self._infer(pred_relations, *args, **kwargs) - def _metric(self, preds, label): - return preds['pred_relations'], label[6], label[5] + def _metric(self, pred_relations, label): + return pred_relations, label[-1], label[-2] - def _infer(self, preds, *args, **kwargs): + def _infer(self, pred_relations, *args, **kwargs): ser_results = kwargs['ser_results'] entity_idx_dict_batch = kwargs['entity_idx_dict_batch'] - pred_relations = preds['pred_relations'] # merge relations and ocr info results = [] @@ -50,6 +54,24 @@ class VQAReTokenLayoutLMPostProcess(object): results.append(result) return results + def decode_pred(self, pred_relations): + pred_relations_new = [] + for pred_relation in pred_relations: + pred_relation_new = [] + pred_relation = pred_relation[1:pred_relation[0, 0, 0] + 1] + for relation in pred_relation: + relation_new = dict() + relation_new['head_id'] = relation[0, 0] + relation_new['head'] = tuple(relation[1]) + relation_new['head_type'] = relation[2, 0] + relation_new['tail_id'] = relation[3, 0] + relation_new['tail'] = tuple(relation[4]) + relation_new['tail_type'] = relation[5, 0] + relation_new['type'] = relation[6, 0] + pred_relation_new.append(relation_new) + pred_relations_new.append(pred_relation_new) + return pred_relations_new + class DistillationRePostProcess(VQAReTokenLayoutLMPostProcess): """ diff --git a/ppocr/utils/e2e_metric/Deteval.py b/ppocr/utils/e2e_metric/Deteval.py index 45567a7dd2d82b6c583abd4a4eabef52974be081..6ce56eda2aa9f38fdc712d49ae64945c558b418d 100755 --- a/ppocr/utils/e2e_metric/Deteval.py +++ b/ppocr/utils/e2e_metric/Deteval.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import json import numpy as np import scipy.io as io +import Polygon as plg from ppocr.utils.e2e_metric.polygon_fast import iod, area_of_intersection, area @@ -269,7 +271,124 @@ def get_socre_B(gt_dir, img_id, pred_dict): return single_data -def combine_results(all_data): +def get_score_C(gt_label, text, pred_bboxes): + """ + get score for CentripetalText (CT) prediction. + """ + + def gt_reading_mod(gt_label, text): + """This helper reads groundtruths from mat files""" + groundtruths = [] + nbox = len(gt_label) + for i in range(nbox): + label = {"transcription": text[i][0], "points": gt_label[i].numpy()} + groundtruths.append(label) + + return groundtruths + + def get_union(pD, pG): + areaA = pD.area() + areaB = pG.area() + return areaA + areaB - get_intersection(pD, pG) + + def get_intersection(pD, pG): + pInt = pD & pG + if len(pInt) == 0: + return 0 + return pInt.area() + + def detection_filtering(detections, groundtruths, threshold=0.5): + for gt in groundtruths: + point_num = gt['points'].shape[1] // 2 + if gt['transcription'] == '###' and (point_num > 1): + gt_p = np.array(gt['points']).reshape(point_num, + 2).astype('int32') + gt_p = plg.Polygon(gt_p) + + for det_id, detection in enumerate(detections): + det_y = detection[0::2] + det_x = detection[1::2] + + det_p = np.concatenate((np.array(det_x), np.array(det_y))) + det_p = det_p.reshape(2, -1).transpose() + det_p = plg.Polygon(det_p) + + try: + det_gt_iou = get_intersection(det_p, + gt_p) / det_p.area() + except: + print(det_x, det_y, gt_p) + if det_gt_iou > threshold: + detections[det_id] = [] + + detections[:] = [item for item in detections if item != []] + return detections + + def sigma_calculation(det_p, gt_p): + """ + sigma = inter_area / gt_area + """ + if gt_p.area() == 0.: + return 0 + return get_intersection(det_p, gt_p) / gt_p.area() + + def tau_calculation(det_p, gt_p): + """ + tau = inter_area / det_area + """ + if det_p.area() == 0.: + return 0 + return get_intersection(det_p, gt_p) / det_p.area() + + detections = [] + + for item in pred_bboxes: + detections.append(item[:, ::-1].reshape(-1)) + + groundtruths = gt_reading_mod(gt_label, text) + + detections = detection_filtering( + detections, groundtruths) # filters detections overlapping with DC area + + for idx in range(len(groundtruths) - 1, -1, -1): + #NOTE: source code use 'orin' to indicate '#', here we use 'anno', + # which may cause slight drop in fscore, about 0.12 + if groundtruths[idx]['transcription'] == '###': + groundtruths.pop(idx) + + local_sigma_table = np.zeros((len(groundtruths), len(detections))) + local_tau_table = np.zeros((len(groundtruths), len(detections))) + + for gt_id, gt in enumerate(groundtruths): + if len(detections) > 0: + for det_id, detection in enumerate(detections): + point_num = gt['points'].shape[1] // 2 + + gt_p = np.array(gt['points']).reshape(point_num, + 2).astype('int32') + gt_p = plg.Polygon(gt_p) + + det_y = detection[0::2] + det_x = detection[1::2] + + det_p = np.concatenate((np.array(det_x), np.array(det_y))) + + det_p = det_p.reshape(2, -1).transpose() + det_p = plg.Polygon(det_p) + + local_sigma_table[gt_id, det_id] = sigma_calculation(det_p, + gt_p) + local_tau_table[gt_id, det_id] = tau_calculation(det_p, gt_p) + + data = {} + data['sigma'] = local_sigma_table + data['global_tau'] = local_tau_table + data['global_pred_str'] = '' + data['global_gt_str'] = '' + return data + + +def combine_results(all_data, rec_flag=True): tr = 0.7 tp = 0.6 fsc_k = 0.8 @@ -278,6 +397,7 @@ def combine_results(all_data): global_tau = [] 
global_pred_str = [] global_gt_str = [] + for data in all_data: global_sigma.append(data['sigma']) global_tau.append(data['global_tau']) @@ -294,7 +414,7 @@ def combine_results(all_data): def one_to_one(local_sigma_table, local_tau_table, local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, - gt_flag, det_flag, idy): + gt_flag, det_flag, idy, rec_flag): hit_str_num = 0 for gt_id in range(num_gt): gt_matching_qualified_sigma_candidates = np.where( @@ -328,14 +448,15 @@ def combine_results(all_data): gt_flag[0, gt_id] = 1 matched_det_id = np.where(local_sigma_table[gt_id, :] > tr) # recg start - gt_str_cur = global_gt_str[idy][gt_id] - pred_str_cur = global_pred_str[idy][matched_det_id[0].tolist()[ - 0]] - if pred_str_cur == gt_str_cur: - hit_str_num += 1 - else: - if pred_str_cur.lower() == gt_str_cur.lower(): + if rec_flag: + gt_str_cur = global_gt_str[idy][gt_id] + pred_str_cur = global_pred_str[idy][matched_det_id[0] + .tolist()[0]] + if pred_str_cur == gt_str_cur: hit_str_num += 1 + else: + if pred_str_cur.lower() == gt_str_cur.lower(): + hit_str_num += 1 # recg end det_flag[0, matched_det_id] = 1 return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num @@ -343,7 +464,7 @@ def combine_results(all_data): def one_to_many(local_sigma_table, local_tau_table, local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, - gt_flag, det_flag, idy): + gt_flag, det_flag, idy, rec_flag): hit_str_num = 0 for gt_id in range(num_gt): # skip the following if the groundtruth was matched @@ -374,28 +495,30 @@ def combine_results(all_data): gt_flag[0, gt_id] = 1 det_flag[0, qualified_tau_candidates] = 1 # recg start - gt_str_cur = global_gt_str[idy][gt_id] - pred_str_cur = global_pred_str[idy][ - qualified_tau_candidates[0].tolist()[0]] - if pred_str_cur == gt_str_cur: - hit_str_num += 1 - else: - if pred_str_cur.lower() == gt_str_cur.lower(): + if rec_flag: + gt_str_cur = global_gt_str[idy][gt_id] + pred_str_cur = global_pred_str[idy][ + qualified_tau_candidates[0].tolist()[0]] + if pred_str_cur == gt_str_cur: hit_str_num += 1 + else: + if pred_str_cur.lower() == gt_str_cur.lower(): + hit_str_num += 1 # recg end elif (np.sum(local_sigma_table[gt_id, qualified_tau_candidates]) >= tr): gt_flag[0, gt_id] = 1 det_flag[0, qualified_tau_candidates] = 1 # recg start - gt_str_cur = global_gt_str[idy][gt_id] - pred_str_cur = global_pred_str[idy][ - qualified_tau_candidates[0].tolist()[0]] - if pred_str_cur == gt_str_cur: - hit_str_num += 1 - else: - if pred_str_cur.lower() == gt_str_cur.lower(): + if rec_flag: + gt_str_cur = global_gt_str[idy][gt_id] + pred_str_cur = global_pred_str[idy][ + qualified_tau_candidates[0].tolist()[0]] + if pred_str_cur == gt_str_cur: hit_str_num += 1 + else: + if pred_str_cur.lower() == gt_str_cur.lower(): + hit_str_num += 1 # recg end global_accumulative_recall = global_accumulative_recall + fsc_k @@ -409,7 +532,7 @@ def combine_results(all_data): def many_to_one(local_sigma_table, local_tau_table, local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, - gt_flag, det_flag, idy): + gt_flag, det_flag, idy, rec_flag): hit_str_num = 0 for det_id in range(num_det): # skip the following if the detection was matched @@ -440,6 +563,30 @@ def combine_results(all_data): gt_flag[0, qualified_sigma_candidates] = 1 
det_flag[0, det_id] = 1 # recg start + if rec_flag: + pred_str_cur = global_pred_str[idy][det_id] + gt_len = len(qualified_sigma_candidates[0]) + for idx in range(gt_len): + ele_gt_id = qualified_sigma_candidates[ + 0].tolist()[idx] + if ele_gt_id not in global_gt_str[idy]: + continue + gt_str_cur = global_gt_str[idy][ele_gt_id] + if pred_str_cur == gt_str_cur: + hit_str_num += 1 + break + else: + if pred_str_cur.lower() == gt_str_cur.lower( + ): + hit_str_num += 1 + break + # recg end + elif (np.sum(local_tau_table[qualified_sigma_candidates, + det_id]) >= tp): + det_flag[0, det_id] = 1 + gt_flag[0, qualified_sigma_candidates] = 1 + # recg start + if rec_flag: pred_str_cur = global_pred_str[idy][det_id] gt_len = len(qualified_sigma_candidates[0]) for idx in range(gt_len): @@ -454,27 +601,7 @@ def combine_results(all_data): else: if pred_str_cur.lower() == gt_str_cur.lower(): hit_str_num += 1 - break - # recg end - elif (np.sum(local_tau_table[qualified_sigma_candidates, - det_id]) >= tp): - det_flag[0, det_id] = 1 - gt_flag[0, qualified_sigma_candidates] = 1 - # recg start - pred_str_cur = global_pred_str[idy][det_id] - gt_len = len(qualified_sigma_candidates[0]) - for idx in range(gt_len): - ele_gt_id = qualified_sigma_candidates[0].tolist()[idx] - if ele_gt_id not in global_gt_str[idy]: - continue - gt_str_cur = global_gt_str[idy][ele_gt_id] - if pred_str_cur == gt_str_cur: - hit_str_num += 1 - break - else: - if pred_str_cur.lower() == gt_str_cur.lower(): - hit_str_num += 1 - break + break # recg end global_accumulative_recall = global_accumulative_recall + num_qualified_sigma_candidates * fsc_k @@ -504,7 +631,7 @@ def combine_results(all_data): gt_flag, det_flag, hit_str_num = one_to_one(local_sigma_table, local_tau_table, local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, - gt_flag, det_flag, idx) + gt_flag, det_flag, idx, rec_flag) hit_str_count += hit_str_num #######then check for one-to-many case########## @@ -512,14 +639,14 @@ def combine_results(all_data): gt_flag, det_flag, hit_str_num = one_to_many(local_sigma_table, local_tau_table, local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, - gt_flag, det_flag, idx) + gt_flag, det_flag, idx, rec_flag) hit_str_count += hit_str_num #######then check for many-to-one case########## local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \ gt_flag, det_flag, hit_str_num = many_to_one(local_sigma_table, local_tau_table, local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, - gt_flag, det_flag, idx) + gt_flag, det_flag, idx, rec_flag) hit_str_count += hit_str_num try: diff --git a/ppocr/utils/e2e_utils/extract_textpoint_fast.py b/ppocr/utils/e2e_utils/extract_textpoint_fast.py index 787cd3017fafa6fc554bead0cc05b5bfe682df42..a85b8e78ead00e64630b57400b9e5141eb0181a8 100644 --- a/ppocr/utils/e2e_utils/extract_textpoint_fast.py +++ b/ppocr/utils/e2e_utils/extract_textpoint_fast.py @@ -88,8 +88,35 @@ def ctc_greedy_decoder(probs_seq, blank=95, keep_blank_in_idxs=True): return dst_str, keep_idx_list -def instance_ctc_greedy_decoder(gather_info, logits_map, pts_num=4): +def instance_ctc_greedy_decoder(gather_info, + logits_map, + pts_num=4, + point_gather_mode=None): _, _, C = logits_map.shape + if point_gather_mode == 'align': + insert_num = 0 + gather_info = np.array(gather_info) + length = 
len(gather_info) - 1 + for index in range(length): + stride_y = np.abs(gather_info[index + insert_num][0] - gather_info[ + index + 1 + insert_num][0]) + stride_x = np.abs(gather_info[index + insert_num][1] - gather_info[ + index + 1 + insert_num][1]) + max_points = int(max(stride_x, stride_y)) + stride = (gather_info[index + insert_num] - + gather_info[index + 1 + insert_num]) / (max_points) + insert_num_temp = max_points - 1 + + for i in range(int(insert_num_temp)): + insert_value = gather_info[index + insert_num] - (i + 1 + ) * stride + insert_index = index + i + 1 + insert_num + gather_info = np.insert( + gather_info, insert_index, insert_value, axis=0) + insert_num += insert_num_temp + gather_info = gather_info.tolist() + else: + pass ys, xs = zip(*gather_info) logits_seq = logits_map[list(ys), list(xs)] probs_seq = logits_seq @@ -104,7 +131,8 @@ def instance_ctc_greedy_decoder(gather_info, logits_map, pts_num=4): def ctc_decoder_for_image(gather_info_list, logits_map, Lexicon_Table, - pts_num=6): + pts_num=6, + point_gather_mode=None): """ CTC decoder using multiple processes. """ @@ -114,7 +142,10 @@ def ctc_decoder_for_image(gather_info_list, if len(gather_info) < pts_num: continue dst_str, xys_list = instance_ctc_greedy_decoder( - gather_info, logits_map, pts_num=pts_num) + gather_info, + logits_map, + pts_num=pts_num, + point_gather_mode=point_gather_mode) dst_str_readable = ''.join([Lexicon_Table[idx] for idx in dst_str]) if len(dst_str_readable) < 2: continue @@ -356,7 +387,8 @@ def generate_pivot_list_fast(p_score, p_char_maps, f_direction, Lexicon_Table, - score_thresh=0.5): + score_thresh=0.5, + point_gather_mode=None): """ return center point and end point of TCL instance; filter with the char maps; """ @@ -384,7 +416,10 @@ def generate_pivot_list_fast(p_score, p_char_maps = p_char_maps.transpose([1, 2, 0]) decoded_str, keep_yxs_list = ctc_decoder_for_image( - all_pos_yxs, logits_map=p_char_maps, Lexicon_Table=Lexicon_Table) + all_pos_yxs, + logits_map=p_char_maps, + Lexicon_Table=Lexicon_Table, + point_gather_mode=point_gather_mode) return keep_yxs_list, decoded_str diff --git a/ppocr/utils/e2e_utils/pgnet_pp_utils.py b/ppocr/utils/e2e_utils/pgnet_pp_utils.py index a15503c0a88f735cc5f5eef924b0d022e5684eed..06a766b0e714e2792c0b0d3069963de998eb9eb7 100644 --- a/ppocr/utils/e2e_utils/pgnet_pp_utils.py +++ b/ppocr/utils/e2e_utils/pgnet_pp_utils.py @@ -28,13 +28,19 @@ from extract_textpoint_fast import generate_pivot_list_fast, restore_poly class PGNet_PostProcess(object): # two different post-process - def __init__(self, character_dict_path, valid_set, score_thresh, outs_dict, - shape_list): + def __init__(self, + character_dict_path, + valid_set, + score_thresh, + outs_dict, + shape_list, + point_gather_mode=None): self.Lexicon_Table = get_dict(character_dict_path) self.valid_set = valid_set self.score_thresh = score_thresh self.outs_dict = outs_dict self.shape_list = shape_list + self.point_gather_mode = point_gather_mode def pg_postprocess_fast(self): p_score = self.outs_dict['f_score'] @@ -58,7 +64,8 @@ class PGNet_PostProcess(object): p_char, p_direction, self.Lexicon_Table, - score_thresh=self.score_thresh) + score_thresh=self.score_thresh, + point_gather_mode=self.point_gather_mode) poly_list, keep_str_list = restore_poly(instance_yxs_list, seq_strs, p_border, ratio_w, ratio_h, src_w, src_h, self.valid_set) diff --git a/ppstructure/docs/layout/layout.png b/ppstructure/docs/layout/layout.png index 
da9640e245e34659771353e328bf97da129bd622..66b95486955b5f45f3f0c16e1ed6577914cc2c7c 100644 Binary files a/ppstructure/docs/layout/layout.png and b/ppstructure/docs/layout/layout.png differ diff --git a/ppstructure/docs/models_list.md b/ppstructure/docs/models_list.md index 935d12d756eec467574f9ae32d48c70a3ea054c3..afed95600f0858b1423a105c4f5bcd3e092211ab 100644 --- a/ppstructure/docs/models_list.md +++ b/ppstructure/docs/models_list.md @@ -51,9 +51,9 @@ |模型名称|模型简介 | 推理模型大小| 精度(hmean) | 预测耗时(ms) | 下载地址| | --- | --- | --- |--- |--- | --- | |ser_VI-LayoutXLM_xfund_zh|基于VI-LayoutXLM在xfund中文数据集上训练的SER模型|1.1G| 93.19% | 15.49 | [推理模型](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_pretrained.tar) | -|re_VI-LayoutXLM_xfund_zh|基于VI-LayoutXLM在xfund中文数据集上训练的RE模型|1.1G| 83.92% | 15.49 |[推理模型 coming soon]() / [训练模型](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_pretrained.tar) | +|re_VI-LayoutXLM_xfund_zh|基于VI-LayoutXLM在xfund中文数据集上训练的RE模型|1.1G| 83.92% | 15.49 |[推理模型](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_pretrained.tar) | |ser_LayoutXLM_xfund_zh|基于LayoutXLM在xfund中文数据集上训练的SER模型|1.4G| 90.38% | 19.49 |[推理模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar) | -|re_LayoutXLM_xfund_zh|基于LayoutXLM在xfund中文数据集上训练的RE模型|1.4G| 74.83% | 19.49 |[推理模型 coming soon]() / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar) | +|re_LayoutXLM_xfund_zh|基于LayoutXLM在xfund中文数据集上训练的RE模型|1.4G| 74.83% | 19.49 |[推理模型](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar) | |ser_LayoutLMv2_xfund_zh|基于LayoutLMv2在xfund中文数据集上训练的SER模型|778M| 85.44% | 31.46 |[推理模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLMv2_xfun_zh_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLMv2_xfun_zh.tar) | |re_LayoutLMv2_xfund_zh|基于LayoutLMv2在xfun中文数据集上训练的RE模型|765M| 67.77% | 31.46 |[推理模型 coming soon]() / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutLMv2_xfun_zh.tar) | |ser_LayoutLM_xfund_zh|基于LayoutLM在xfund中文数据集上训练的SER模型|430M| 77.31% | - |[推理模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLM_xfun_zh_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutLM_xfun_zh.tar) | diff --git a/ppstructure/kie/README.md b/ppstructure/kie/README.md index 562ebb9e25b09015150d3265a7b9a6c8c74e7aae..872edb959276e22b22e4b733df44bdb6a6819c98 100644 --- a/ppstructure/kie/README.md +++ b/ppstructure/kie/README.md @@ -172,16 +172,16 @@ If you want to use OCR engine to obtain end-to-end prediction results, you can u # just predict using SER trained model python3 tools/infer_kie_token_ser.py \ -c configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh.yml \ - -o Architecture.Backbone.checkpoints=./pretrain_models/ser_vi_layoutxlm_xfund_pretrained/best_accuracy \ + -o Architecture.Backbone.checkpoints=./pretrained_model/ser_vi_layoutxlm_xfund_pretrained/best_accuracy \ Global.infer_img=./ppstructure/docs/kie/input/zh_val_42.jpg # predict using SER and RE trained model at the same time python3 ./tools/infer_kie_token_ser_re.py \ -c 
configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh.yml \ - -o Architecture.Backbone.checkpoints=./pretrain_models/re_vi_layoutxlm_xfund_pretrained/best_accuracy \ + -o Architecture.Backbone.checkpoints=./pretrained_model/re_vi_layoutxlm_xfund_pretrained/best_accuracy \ Global.infer_img=./train_data/XFUND/zh_val/image/zh_val_42.jpg \ -c_ser configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh.yml \ - -o_ser Architecture.Backbone.checkpoints=./pretrain_models/ser_vi_layoutxlm_xfund_pretrained/best_accuracy + -o_ser Architecture.Backbone.checkpoints=./pretrained_model/ser_vi_layoutxlm_xfund_pretrained/best_accuracy ``` The visual result images and the predicted text file will be saved in the `Global.save_res_path` directory. @@ -193,33 +193,34 @@ If you want to load the text detection and recognition results collected before, # just predict using SER trained model python3 tools/infer_kie_token_ser.py \ -c configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh.yml \ - -o Architecture.Backbone.checkpoints=./pretrain_models/ser_vi_layoutxlm_xfund_pretrained/best_accuracy \ + -o Architecture.Backbone.checkpoints=./pretrained_model/ser_vi_layoutxlm_xfund_pretrained/best_accuracy \ Global.infer_img=./train_data/XFUND/zh_val/val.json \ Global.infer_mode=False # predict using SER and RE trained model at the same time python3 ./tools/infer_kie_token_ser_re.py \ -c configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh.yml \ - -o Architecture.Backbone.checkpoints=./pretrain_models/re_vi_layoutxlm_xfund_pretrained/best_accuracy \ + -o Architecture.Backbone.checkpoints=./pretrained_model/re_vi_layoutxlm_xfund_pretrained/best_accuracy \ Global.infer_img=./train_data/XFUND/zh_val/val.json \ Global.infer_mode=False \ -c_ser configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh.yml \ - -o_ser Architecture.Backbone.checkpoints=./pretrain_models/ser_vi_layoutxlm_xfund_pretrained/best_accuracy + -o_ser Architecture.Backbone.checkpoints=./pretrained_model/ser_vi_layoutxlm_xfund_pretrained/best_accuracy ``` #### 4.2.3 Inference using PaddleInference -At present, only SER model supports inference using PaddleInference. - Firstly, download the inference SER inference model. - ```bash mkdir inference cd inference wget https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_infer.tar && tar -xf ser_vi_layoutxlm_xfund_infer.tar +wget https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_infer.tar && tar -xf re_vi_layoutxlm_xfund_infer.tar +cd .. ``` +- SER + Use the following command for inference. @@ -236,6 +237,26 @@ python3 kie/predict_kie_token_ser.py \ The visual results and text file will be saved in directory `output`. +- RE + +Use the following command for inference. + + +```bash +cd ppstructure +python3 kie/predict_kie_token_ser_re.py \ + --kie_algorithm=LayoutXLM \ + --re_model_dir=../inference/re_vi_layoutxlm_xfund_infer \ + --ser_model_dir=../inference/ser_vi_layoutxlm_xfund_infer \ + --use_visual_backbone=False \ + --image_dir=./docs/kie/input/zh_val_42.jpg \ + --ser_dict_path=../train_data/XFUND/class_list_xfun.txt \ + --vis_font_path=../doc/fonts/simfang.ttf \ + --ocr_order_method="tb-yx" +``` + +The visual results and text file will be saved in directory `output`. 
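The same chain can also be called from Python via the new `SerRePredictor` class in `ppstructure/kie/predict_kie_token_ser_re.py`. The following is a minimal sketch; the flag values mirror the command above, the paths are assumptions, and it expects to be run from the repository root so that `ppstructure` is importable:

```python
import cv2

from ppstructure.utility import parse_args
from ppstructure.kie.predict_kie_token_ser_re import SerRePredictor

args = parse_args()  # parses sys.argv; the relevant flags are overridden explicitly below
args.kie_algorithm = "LayoutXLM"
args.re_model_dir = "./inference/re_vi_layoutxlm_xfund_infer"
args.ser_model_dir = "./inference/ser_vi_layoutxlm_xfund_infer"
args.use_visual_backbone = False
args.ser_dict_path = "./train_data/XFUND/class_list_xfun.txt"
args.ocr_order_method = "tb-yx"

predictor = SerRePredictor(args)

# BGR -> RGB, as done in the script's main()
img = cv2.imread("ppstructure/docs/kie/input/zh_val_42.jpg")[:, :, ::-1]
re_result, elapse = predictor(img)
print(re_result[0])  # decoded relations merged with the SER/OCR info for this image
```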
+ ### 4.3 More diff --git a/ppstructure/kie/README_ch.md b/ppstructure/kie/README_ch.md index 56c99ab73abe2b33ccfa18d4181312cd5f4d3622..7a8b1942b1849834f8843c8f272ce08e95f4b993 100644 --- a/ppstructure/kie/README_ch.md +++ b/ppstructure/kie/README_ch.md @@ -156,16 +156,16 @@ wget https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layou # 仅预测SER模型 python3 tools/infer_kie_token_ser.py \ -c configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh.yml \ - -o Architecture.Backbone.checkpoints=./pretrain_models/ser_vi_layoutxlm_xfund_pretrained/best_accuracy \ + -o Architecture.Backbone.checkpoints=./pretrained_model/ser_vi_layoutxlm_xfund_pretrained/best_accuracy \ Global.infer_img=./ppstructure/docs/kie/input/zh_val_42.jpg # SER + RE模型串联 python3 ./tools/infer_kie_token_ser_re.py \ -c configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh.yml \ - -o Architecture.Backbone.checkpoints=./pretrain_models/re_vi_layoutxlm_xfund_pretrained/best_accuracy \ + -o Architecture.Backbone.checkpoints=./pretrained_model/re_vi_layoutxlm_xfund_pretrained/best_accuracy \ Global.infer_img=./train_data/XFUND/zh_val/image/zh_val_42.jpg \ -c_ser configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh.yml \ - -o_ser Architecture.Backbone.checkpoints=./pretrain_models/ser_vi_layoutxlm_xfund_pretrained/best_accuracy + -o_ser Architecture.Backbone.checkpoints=./pretrained_model/ser_vi_layoutxlm_xfund_pretrained/best_accuracy ``` `Global.save_res_path`目录中会保存可视化的结果图像以及预测的文本文件。 @@ -177,33 +177,34 @@ python3 ./tools/infer_kie_token_ser_re.py \ # 仅预测SER模型 python3 tools/infer_kie_token_ser.py \ -c configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh.yml \ - -o Architecture.Backbone.checkpoints=./pretrain_models/ser_vi_layoutxlm_xfund_pretrained/best_accuracy \ + -o Architecture.Backbone.checkpoints=./pretrained_model/ser_vi_layoutxlm_xfund_pretrained/best_accuracy \ Global.infer_img=./train_data/XFUND/zh_val/val.json \ Global.infer_mode=False # SER + RE模型串联 python3 ./tools/infer_kie_token_ser_re.py \ -c configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh.yml \ - -o Architecture.Backbone.checkpoints=./pretrain_models/re_vi_layoutxlm_xfund_pretrained/best_accuracy \ + -o Architecture.Backbone.checkpoints=./pretrained_model/re_vi_layoutxlm_xfund_pretrained/best_accuracy \ Global.infer_img=./train_data/XFUND/zh_val/val.json \ Global.infer_mode=False \ -c_ser configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh.yml \ - -o_ser Architecture.Backbone.checkpoints=./pretrain_models/ser_vi_layoutxlm_xfund_pretrained/best_accuracy + -o_ser Architecture.Backbone.checkpoints=./pretrained_model/ser_vi_layoutxlm_xfund_pretrained/best_accuracy ``` #### 4.2.3 基于PaddleInference的预测 -目前仅SER模型支持PaddleInference推理。 - -首先下载SER的推理模型。 - +首先下载SER和RE的推理模型。 ```bash mkdir inference cd inference wget https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_infer.tar && tar -xf ser_vi_layoutxlm_xfund_infer.tar +wget https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_infer.tar && tar -xf re_vi_layoutxlm_xfund_infer.tar +cd .. 
``` +- SER + 执行下面的命令进行预测。 ```bash @@ -219,6 +220,26 @@ python3 kie/predict_kie_token_ser.py \ 可视化结果保存在`output`目录下。 +- RE + +执行下面的命令进行预测。 + +```bash +cd ppstructure +python3 kie/predict_kie_token_ser_re.py \ + --kie_algorithm=LayoutXLM \ + --re_model_dir=../inference/re_vi_layoutxlm_xfund_infer \ + --ser_model_dir=../inference/ser_vi_layoutxlm_xfund_infer \ + --use_visual_backbone=False \ + --image_dir=./docs/kie/input/zh_val_42.jpg \ + --ser_dict_path=../train_data/XFUND/class_list_xfun.txt \ + --vis_font_path=../doc/fonts/simfang.ttf \ + --ocr_order_method="tb-yx" +``` + +可视化结果保存在`output`目录下。 + + ### 4.3 更多 关于KIE模型的训练评估与推理,请参考:[关键信息抽取教程](../../doc/doc_ch/kie.md)。 diff --git a/ppstructure/kie/predict_kie_token_ser.py b/ppstructure/kie/predict_kie_token_ser.py index 48cfc528a28e0a2bdfb51d3a537f26e891ae3286..e570979bcb419edbc2e09e190ae36ec1458c1826 100644 --- a/ppstructure/kie/predict_kie_token_ser.py +++ b/ppstructure/kie/predict_kie_token_ser.py @@ -102,16 +102,18 @@ class SerPredictor(object): ori_im = img.copy() data = {'image': img} data = transform(data, self.preprocess_op) - img = data[0] - if img is None: + if data[0] is None: return None, 0 - img = np.expand_dims(img, axis=0) - img = img.copy() starttime = time.time() + for idx in range(len(data)): + if isinstance(data[idx], np.ndarray): + data[idx] = np.expand_dims(data[idx], axis=0) + else: + data[idx] = [data[idx]] + for idx in range(len(self.input_tensor)): - expand_input = np.expand_dims(data[idx], axis=0) - self.input_tensor[idx].copy_from_cpu(expand_input) + self.input_tensor[idx].copy_from_cpu(data[idx]) self.predictor.run() @@ -122,9 +124,9 @@ class SerPredictor(object): preds = outputs[0] post_result = self.postprocess_op( - preds, segment_offset_ids=[data[6]], ocr_infos=[data[7]]) + preds, segment_offset_ids=data[6], ocr_infos=data[7]) elapse = time.time() - starttime - return post_result, elapse + return post_result, data, elapse def main(args): @@ -145,7 +147,7 @@ def main(args): if img is None: logger.info("error in loading image:{}".format(image_file)) continue - ser_res, elapse = ser_predictor(img) + ser_res, _, elapse = ser_predictor(img) ser_res = ser_res[0] res_str = '{}\t{}\n'.format( diff --git a/ppstructure/kie/predict_kie_token_ser_re.py b/ppstructure/kie/predict_kie_token_ser_re.py new file mode 100644 index 0000000000000000000000000000000000000000..278e08da918ab8f77062b444becd399b4ea2c0b6 --- /dev/null +++ b/ppstructure/kie/predict_kie_token_ser_re.py @@ -0,0 +1,127 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
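+#
+# This script chains the two KIE stages for PaddleInference deployment: the
+# SerPredictor from predict_kie_token_ser.py first runs semantic entity
+# recognition and returns both its post-processed results and its preprocessed
+# inputs; make_input() then rebuilds the relation-extraction inputs from those
+# outputs, the RE model is run, and VQAReTokenLayoutLMPostProcess decodes the
+# predicted relations. When --use_visual_backbone=False the image tensor
+# (index 4) is dropped from the RE inputs, matching VI-LayoutXLM models that
+# have no visual backbone.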
+import os +import sys + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(__dir__) +sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..'))) + +os.environ["FLAGS_allocator_strategy"] = 'auto_growth' + +import cv2 +import json +import numpy as np +import time + +import tools.infer.utility as utility +from tools.infer_kie_token_ser_re import make_input +from ppocr.postprocess import build_post_process +from ppocr.utils.logging import get_logger +from ppocr.utils.visual import draw_re_results +from ppocr.utils.utility import get_image_file_list, check_and_read +from ppstructure.utility import parse_args +from ppstructure.kie.predict_kie_token_ser import SerPredictor + +from paddleocr import PaddleOCR + +logger = get_logger() + + +class SerRePredictor(object): + def __init__(self, args): + self.use_visual_backbone = args.use_visual_backbone + self.ser_engine = SerPredictor(args) + + postprocess_params = {'name': 'VQAReTokenLayoutLMPostProcess'} + self.postprocess_op = build_post_process(postprocess_params) + self.predictor, self.input_tensor, self.output_tensors, self.config = \ + utility.create_predictor(args, 're', logger) + + def __call__(self, img): + ori_im = img.copy() + starttime = time.time() + ser_results, ser_inputs, _ = self.ser_engine(img) + re_input, entity_idx_dict_batch = make_input(ser_inputs, ser_results) + if self.use_visual_backbone == False: + re_input.pop(4) + for idx in range(len(self.input_tensor)): + self.input_tensor[idx].copy_from_cpu(re_input[idx]) + + self.predictor.run() + outputs = [] + for output_tensor in self.output_tensors: + output = output_tensor.copy_to_cpu() + outputs.append(output) + preds = dict( + loss=outputs[1], + pred_relations=outputs[2], + hidden_states=outputs[0], ) + + post_result = self.postprocess_op( + preds, + ser_results=ser_results, + entity_idx_dict_batch=entity_idx_dict_batch) + + elapse = time.time() - starttime + return post_result, elapse + + +def main(args): + image_file_list = get_image_file_list(args.image_dir) + ser_predictor = SerRePredictor(args) + count = 0 + total_time = 0 + + os.makedirs(args.output, exist_ok=True) + with open( + os.path.join(args.output, 'infer.txt'), mode='w', + encoding='utf-8') as f_w: + for image_file in image_file_list: + img, flag, _ = check_and_read(image_file) + if not flag: + img = cv2.imread(image_file) + img = img[:, :, ::-1] + if img is None: + logger.info("error in loading image:{}".format(image_file)) + continue + re_res, elapse = ser_predictor(img) + re_res = re_res[0] + + res_str = '{}\t{}\n'.format( + image_file, + json.dumps( + { + "ocr_info": re_res, + }, ensure_ascii=False)) + f_w.write(res_str) + + img_res = draw_re_results( + image_file, re_res, font_path=args.vis_font_path) + + img_save_path = os.path.join( + args.output, + os.path.splitext(os.path.basename(image_file))[0] + + "_ser_re.jpg") + + cv2.imwrite(img_save_path, img_res) + logger.info("save vis result to {}".format(img_save_path)) + if count > 0: + total_time += elapse + count += 1 + logger.info("Predict time of {}: {}".format(image_file, elapse)) + + +if __name__ == "__main__": + main(parse_args()) diff --git a/ppstructure/layout/README.md b/ppstructure/layout/README.md index 84b977fdd760e6de43d355b802731b5d43eb2cf5..6830f8e82153f8ae7d2e798cda6782bc5518da4c 100644 --- a/ppstructure/layout/README.md +++ b/ppstructure/layout/README.md @@ -23,7 +23,7 @@ English | [简体中文](README_ch.md) ## 1. 
Introduction -Layout analysis refers to the regional division of documents in the form of pictures and the positioning of key areas, such as text, title, table, picture, etc. The layout analysis algorithm is based on the lightweight model PP-picodet of [PaddleDetection]( https://github.com/PaddlePaddle/PaddleDetection ) +Layout analysis refers to the regional division of documents in the form of pictures and the positioning of key areas, such as text, title, table, picture, etc. The layout analysis algorithm is based on the lightweight model PP-picodet of [PaddleDetection]( https://github.com/PaddlePaddle/PaddleDetection ), including English layout analysis, Chinese layout analysis and table layout analysis models. English layout analysis models can detect document layout elements such as text, title, table, figure, list. Chinese layout analysis models can detect document layout elements such as text, figure, figure caption, table, table caption, header, footer, reference, and equation. Table layout analysis models can detect table regions.
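For a quick end-to-end try-out of these layout models, the `paddleocr` whl package wraps them in `PPStructure`. The snippet below is a minimal sketch based on the quickstart interface; it is illustrative only and not part of this change:

```python
import cv2
from paddleocr import PPStructure

# layout analysis only: skip table recognition and per-region OCR
engine = PPStructure(table=False, ocr=False, show_log=True)

img = cv2.imread('ppstructure/docs/table/1.png')
result = engine(img)
for region in result:
    # every detected region carries a category label and a bounding box
    print(region['type'], region['bbox'])
```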
@@ -152,7 +152,7 @@ We provide CDLA(Chinese layout analysis), TableBank(Table layout analysis)etc. d
 | [cTDaR2019_cTDaR](https://cndplab-founder.github.io/cTDaR2019/) | For form detection (TRACKA) and form identification (TRACKB).Image types include historical data sets (beginning with cTDaR_t0, such as CTDAR_T00872.jpg) and modern data sets (beginning with cTDaR_t1, CTDAR_T10482.jpg). |
 | [IIIT-AR-13K](http://cvit.iiit.ac.in/usodi/iiitar13k.php) | Data sets constructed by manually annotating figures or pages from publicly available annual reports, containing 5 categories:table, figure, natural image, logo, and signature. |
 | [TableBank](https://github.com/doc-analysis/TableBank) | For table detection and recognition of large datasets, including Word and Latex document formats |
-| [CDLA](https://github.com/buptlihang/CDLA) | Chinese document layout analysis data set, for Chinese literature (paper) scenarios, including 10 categories:Table, Figure, Figure caption, Table, Table caption, Header, Footer, Reference, Equation |
+| [CDLA](https://github.com/buptlihang/CDLA) | Chinese document layout analysis data set, for Chinese literature (paper) scenarios, including 10 categories:Text, Title, Figure, Figure caption, Table, Table caption, Header, Footer, Reference, Equation |
 | [DocBank](https://github.com/doc-analysis/DocBank) | Large-scale dataset (500K document pages) constructed using weakly supervised methods for document layout analysis, containing 12 categories:Author, Caption, Date, Equation, Figure, Footer, List, Paragraph, Reference, Section, Table, Title |
 
@@ -175,7 +175,7 @@ If the test image is Chinese, the pre-trained model of Chinese CDLA dataset can
 
 ### 5.1. Train
 
-Train:
+Start training with the PaddleDetection [layout analysis configuration file](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.5/configs/picodet/legacy_model/application/layout_analysis).
 
 * Modify Profile
 
diff --git a/ppstructure/layout/README_ch.md b/ppstructure/layout/README_ch.md
index 46d2ba74b2d5c579d4b25cf0cadac22ebc32e5b2..adef46d47389a50bf34500eee1aaf52ff5dfe449 100644
--- a/ppstructure/layout/README_ch.md
+++ b/ppstructure/layout/README_ch.md
@@ -22,7 +22,7 @@
 
 ## 1. 简介
 
-版面分析指的是对图片形式的文档进行区域划分,定位其中的关键区域,如文字、标题、表格、图片等。版面分析算法基于[PaddleDetection](https://github.com/PaddlePaddle/PaddleDetection)的轻量模型PP-PicoDet进行开发。
+版面分析指的是对图片形式的文档进行区域划分,定位其中的关键区域,如文字、标题、表格、图片等。版面分析算法基于[PaddleDetection](https://github.com/PaddlePaddle/PaddleDetection)的轻量模型PP-PicoDet进行开发,包含英文、中文、表格版面分析3类模型。其中,英文模型支持Text、Title、Table、Figure、List 5类区域的检测,中文模型支持Text、Title、Figure、Figure caption、Table、Table caption、Header、Footer、Reference、Equation 10类区域的检测,表格版面分析支持Table区域的检测,版面分析效果如下图所示:
@@ -152,7 +152,7 @@ json文件包含所有图像的标注,数据以字典嵌套的方式存放,
 | ------------------------------------------------------------ | ------------------------------------------------------------ |
 | [cTDaR2019_cTDaR](https://cndplab-founder.github.io/cTDaR2019/) | 用于表格检测(TRACKA)和表格识别(TRACKB)。图片类型包含历史数据集(以cTDaR_t0开头,如cTDaR_t00872.jpg)和现代数据集(以cTDaR_t1开头,cTDaR_t10482.jpg)。 |
 | [IIIT-AR-13K](http://cvit.iiit.ac.in/usodi/iiitar13k.php) | 手动注释公开的年度报告中的图形或页面而构建的数据集,包含5类:table, figure, natural image, logo, and signature |
-| [CDLA](https://github.com/buptlihang/CDLA) | 中文文档版面分析数据集,面向中文文献类(论文)场景,包含10类:Table、Figure、Figure caption、Table、Table caption、Header、Footer、Reference、Equation |
+| [CDLA](https://github.com/buptlihang/CDLA) | 中文文档版面分析数据集,面向中文文献类(论文)场景,包含10类:Text、Title、Figure、Figure caption、Table、Table caption、Header、Footer、Reference、Equation |
 | [TableBank](https://github.com/doc-analysis/TableBank) | 用于表格检测和识别大型数据集,包含Word和Latex2种文档格式 |
 | [DocBank](https://github.com/doc-analysis/DocBank) | 使用弱监督方法构建的大规模数据集(500K文档页面),用于文档布局分析,包含12类:Author、Caption、Date、Equation、Figure、Footer、List、Paragraph、Reference、Section、Table、Title |
 
@@ -161,7 +161,7 @@ json文件包含所有图像的标注,数据以字典嵌套的方式存放,
 
 提供了训练脚本、评估脚本和预测脚本,本节将以PubLayNet预训练模型为例进行讲解。
 
-如果不希望训练,直接体验后面的模型评估、预测、动转静、推理的流程,可以下载提供的预训练模型(PubLayNet数据集),并跳过本部分。
+如果不希望训练,直接体验后面的模型评估、预测、动转静、推理的流程,可以下载提供的预训练模型(PubLayNet数据集),并跳过5.1和5.2。
 
 ```
 mkdir pretrained_model
@@ -176,7 +176,7 @@ wget https://paddleocr.bj.bcebos.com/ppstructure/models/layout/picodet_lcnet_x1_
 
 ### 5.1. 启动训练
 
-开始训练:
+使用PaddleDetection[版面分析配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.5/configs/picodet/legacy_model/application/layout_analysis)启动训练
 
 * 修改配置文件
 
diff --git a/ppstructure/predict_system.py b/ppstructure/predict_system.py
index 71147d3af8ec666d368234270dcb0d16aaf91938..b827314b8911859faa449c3322ceceaf10769cf6 100644
--- a/ppstructure/predict_system.py
+++ b/ppstructure/predict_system.py
@@ -100,7 +100,8 @@ class StructureSystem(object):
                 '180': cv2.ROTATE_180,
                 '270': cv2.ROTATE_90_CLOCKWISE
             }
-            img = cv2.rotate(img, cv_rotate_code[angle])
+            if angle in cv_rotate_code:
+                img = cv2.rotate(img, cv_rotate_code[angle])
             toc = time.time()
             time_dict['image_orientation'] = toc - tic
         if self.mode == 'structure':
@@ -254,8 +255,7 @@ def main(args):
 
         if args.recovery and all_res != []:
             try:
-                convert_info_docx(img, all_res, save_folder, img_name,
-                                  args.save_pdf)
+                convert_info_docx(img, all_res, save_folder, img_name)
             except Exception as ex:
                 logger.error("error in layout recovery image:{}, err msg: {}".
                              format(image_file, ex))
diff --git a/ppstructure/recovery/README.md b/ppstructure/recovery/README.md
index 011d6e12fda1b09c7a87367fb887a5c99a4ae00a..0e06c65475b67bcdfc119069fa6f6076322c0e99 100644
--- a/ppstructure/recovery/README.md
+++ b/ppstructure/recovery/README.md
@@ -82,8 +82,11 @@ Through layout analysis, we divided the image/PDF documents into regions, locate
 
 We can restore the test picture through the layout information, OCR detection and recognition structure, table information, and saved pictures.
 
-The whl package is also provided for quick use, see [quickstart](../docs/quickstart_en.md) for details.
+The whl package is also provided for quick use; the command is as follows. For more information, please refer to [quickstart](../docs/quickstart_en.md).
+```bash
+paddleocr --image_dir=ppstructure/docs/table/1.png --type=structure --recovery=true --lang='en'
+```
 
 ### 3.1 Download models
 
diff --git a/ppstructure/recovery/README_ch.md b/ppstructure/recovery/README_ch.md
index fd2e649024ec88e2ea5c88536ccac2e259538886..bc8913adca3385a88cb2decc87fa9acffc707257 100644
--- a/ppstructure/recovery/README_ch.md
+++ b/ppstructure/recovery/README_ch.md
@@ -83,7 +83,16 @@ python3 -m pip install -r ppstructure/recovery/requirements.txt
 
 我们通过版面信息、OCR检测和识别结构、表格信息、保存的图片,对测试图片进行恢复即可。
 
-提供如下代码实现版面恢复,也提供了whl包的形式方便快速使用,详见 [quickstart](../docs/quickstart.md)。
+提供如下代码实现版面恢复,也提供了whl包的形式方便快速使用,代码如下,更多信息详见 [quickstart](../docs/quickstart.md)。
+
+
+```bash +paddleocr --image_dir=ppstructure/docs/table/1.png --type=structure --recovery=true --lang='en' +``` ### 3.1 Download models diff --git a/ppstructure/recovery/README_ch.md b/ppstructure/recovery/README_ch.md index fd2e649024ec88e2ea5c88536ccac2e259538886..bc8913adca3385a88cb2decc87fa9acffc707257 100644 --- a/ppstructure/recovery/README_ch.md +++ b/ppstructure/recovery/README_ch.md @@ -83,7 +83,16 @@ python3 -m pip install -r ppstructure/recovery/requirements.txt 我们通过版面信息、OCR检测和识别结构、表格信息、保存的图片,对测试图片进行恢复即可。 -提供如下代码实现版面恢复,也提供了whl包的形式方便快速使用,详见 [quickstart](../docs/quickstart.md)。 +提供如下代码实现版面恢复,也提供了whl包的形式方便快速使用,代码如下,更多信息详见 [quickstart](../docs/quickstart.md)。 + +```bash +# 中文测试图 +paddleocr --image_dir=ppstructure/docs/table/1.png --type=structure --recovery=true +# 英文测试图 +paddleocr --image_dir=ppstructure/docs/table/1.png --type=structure --recovery=true --lang='en' +# pdf测试文件 +paddleocr --image_dir=ppstructure/recovery/UnrealText.pdf --type=structure --recovery=true --lang='en' +``` diff --git a/ppstructure/recovery/recovery_to_doc.py b/ppstructure/recovery/recovery_to_doc.py index 73b497d49d0961b253738eddad49c88c12c13601..1d8f8d9d4babca7410d6625dbeac4c41668f58a7 100644 --- a/ppstructure/recovery/recovery_to_doc.py +++ b/ppstructure/recovery/recovery_to_doc.py @@ -28,7 +28,7 @@ from ppocr.utils.logging import get_logger logger = get_logger() -def convert_info_docx(img, res, save_folder, img_name, save_pdf=False): +def convert_info_docx(img, res, save_folder, img_name): doc = Document() doc.styles['Normal'].font.name = 'Times New Roman' doc.styles['Normal']._element.rPr.rFonts.set(qn('w:eastAsia'), u'宋体') @@ -60,14 +60,9 @@ def convert_info_docx(img, res, save_folder, img_name, save_pdf=False): elif region['type'].lower() == 'title': doc.add_heading(region['res'][0]['text']) elif region['type'].lower() == 'table': - paragraph = doc.add_paragraph() - new_parser = HtmlToDocx() - new_parser.table_style = 'TableGrid' - table = new_parser.handle_table(html=region['res']['html']) - new_table = deepcopy(table) - new_table.alignment = WD_TABLE_ALIGNMENT.CENTER - paragraph.add_run().element.addnext(new_table._tbl) - + parser = HtmlToDocx() + parser.table_style = 'TableGrid' + parser.handle_table(region['res']['html'], doc) else: paragraph = doc.add_paragraph() paragraph_format = paragraph.paragraph_format @@ -82,13 +77,6 @@ def convert_info_docx(img, res, save_folder, img_name, save_pdf=False): doc.save(docx_path) logger.info('docx save to {}'.format(docx_path)) - # save to pdf - if save_pdf: - pdf_path = os.path.join(save_folder, '{}.pdf'.format(img_name)) - from docx2pdf import convert - convert(docx_path, pdf_path) - logger.info('pdf save to {}'.format(pdf_path)) - def sorted_layout_boxes(res, w): """ diff --git a/ppstructure/recovery/requirements.txt b/ppstructure/recovery/requirements.txt index 25e8cdbb0d58b0a243b176f563c66717d6f4c112..7ddc3391338e5a2a87f9cea9fca006dc03da58fb 100644 --- a/ppstructure/recovery/requirements.txt +++ b/ppstructure/recovery/requirements.txt @@ -1,4 +1,3 @@ python-docx -docx2pdf PyMuPDF beautifulsoup4 \ No newline at end of file diff --git a/ppstructure/recovery/table_process.py b/ppstructure/recovery/table_process.py index 243aaf8933791bf4704964d9665173fe70982f95..982e6b760f9291628d0514728dc8f684f183aa2c 100644 --- a/ppstructure/recovery/table_process.py +++ b/ppstructure/recovery/table_process.py @@ -1,4 +1,3 @@ - # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,62 +12,59 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -This code is refer from:https://github.com/pqzx/html2docx/blob/8f6695a778c68befb302e48ac0ed5201ddbd4524/htmldocx/h2d.py - +This code is refer from: https://github.com/weizwx/html2docx/blob/master/htmldocx/h2d.py """ -import re, argparse -import io, os -import urllib.request -from urllib.parse import urlparse -from html.parser import HTMLParser -import docx, docx.table +import re +import docx from docx import Document -from docx.shared import RGBColor, Pt, Inches -from docx.enum.text import WD_COLOR, WD_ALIGN_PARAGRAPH -from docx.oxml import OxmlElement -from docx.oxml.ns import qn - from bs4 import BeautifulSoup +from html.parser import HTMLParser -# values in inches -INDENT = 0.25 -LIST_INDENT = 0.5 -MAX_INDENT = 5.5 # To stop indents going off the page -# Style to use with tables. By default no style is used. -DEFAULT_TABLE_STYLE = None +def get_table_rows(table_soup): + table_row_selectors = [ + 'table > tr', 'table > thead > tr', 'table > tbody > tr', + 'table > tfoot > tr' + ] + # If there's a header, body, footer or direct child tr tags, add row dimensions from there + return table_soup.select(', '.join(table_row_selectors), recursive=False) -# Style to use with paragraphs. By default no style is used. -DEFAULT_PARAGRAPH_STYLE = None +def get_table_columns(row): + # Get all columns for the specified row tag. + return row.find_all(['th', 'td'], recursive=False) if row else [] -def get_filename_from_url(url): - return os.path.basename(urlparse(url).path) -def is_url(url): - """ - Not to be used for actually validating a url, but in our use case we only - care if it's a url or a file path, and they're pretty distinguishable - """ - parts = urlparse(url) - return all([parts.scheme, parts.netloc, parts.path]) +def get_table_dimensions(table_soup): + # Get rows for the table + rows = get_table_rows(table_soup) + # Table is either empty or has non-direct children between table and tr tags + # Thus the row dimensions and column dimensions are assumed to be 0 -def fetch_image(url): - """ - Attempts to fetch an image from a url. - If successful returns a bytes object, else returns None - :return: - """ - try: - with urllib.request.urlopen(url) as response: - # security flaw? - return io.BytesIO(response.read()) - except urllib.error.URLError: - return None + cols = get_table_columns(rows[0]) if rows else [] + # Add colspan calculation column number + col_count = 0 + for col in cols: + colspan = col.attrs.get('colspan', 1) + col_count += int(colspan) + + return rows, col_count + + +def get_cell_html(soup): + # Returns string of td element with opening and closing tags removed + # Cannot use find_all as it only finds element tags and does not find text which + # is not inside an element + return ' '.join([str(i) for i in soup.contents]) + + +def delete_paragraph(paragraph): + # https://github.com/python-openxml/python-docx/issues/33#issuecomment-77661907 + p = paragraph._element + p.getparent().remove(p) + p._p = p._element = None -def remove_last_occurence(ls, x): - ls.pop(len(ls) - ls[::-1].index(x) - 1) def remove_whitespace(string, leading=False, trailing=False): """Remove white space from a string. @@ -122,11 +118,6 @@ def remove_whitespace(string, leading=False, trailing=False): # TODO need some way to get rid of extra spaces in e.g. 
text text return re.sub(r'\s+', ' ', string) -def delete_paragraph(paragraph): - # https://github.com/python-openxml/python-docx/issues/33#issuecomment-77661907 - p = paragraph._element - p.getparent().remove(p) - p._p = p._element = None font_styles = { 'b': 'bold', @@ -145,13 +136,8 @@ font_names = { 'pre': 'Courier', } -styles = { - 'LIST_BULLET': 'List Bullet', - 'LIST_NUMBER': 'List Number', -} class HtmlToDocx(HTMLParser): - def __init__(self): super().__init__() self.options = { @@ -161,13 +147,11 @@ class HtmlToDocx(HTMLParser): 'styles': True, } self.table_row_selectors = [ - 'table > tr', - 'table > thead > tr', - 'table > tbody > tr', + 'table > tr', 'table > thead > tr', 'table > tbody > tr', 'table > tfoot > tr' ] - self.table_style = DEFAULT_TABLE_STYLE - self.paragraph_style = DEFAULT_PARAGRAPH_STYLE + self.table_style = None + self.paragraph_style = None def set_initial_attrs(self, document=None): self.tags = { @@ -178,9 +162,10 @@ class HtmlToDocx(HTMLParser): self.doc = document else: self.doc = Document() - self.bs = self.options['fix-html'] # whether or not to clean with BeautifulSoup + self.bs = self.options[ + 'fix-html'] # whether or not to clean with BeautifulSoup self.document = self.doc - self.include_tables = True #TODO add this option back in? + self.include_tables = True #TODO add this option back in? self.include_images = self.options['images'] self.include_styles = self.options['styles'] self.paragraph = None @@ -193,55 +178,52 @@ class HtmlToDocx(HTMLParser): self.table_style = other.table_style self.paragraph_style = other.paragraph_style - def get_cell_html(self, soup): - # Returns string of td element with opening and closing tags removed - # Cannot use find_all as it only finds element tags and does not find text which - # is not inside an element - return ' '.join([str(i) for i in soup.contents]) - - def add_styles_to_paragraph(self, style): - if 'text-align' in style: - align = style['text-align'] - if align == 'center': - self.paragraph.paragraph_format.alignment = WD_ALIGN_PARAGRAPH.CENTER - elif align == 'right': - self.paragraph.paragraph_format.alignment = WD_ALIGN_PARAGRAPH.RIGHT - elif align == 'justify': - self.paragraph.paragraph_format.alignment = WD_ALIGN_PARAGRAPH.JUSTIFY - if 'margin-left' in style: - margin = style['margin-left'] - units = re.sub(r'[0-9]+', '', margin) - margin = int(float(re.sub(r'[a-z]+', '', margin))) - if units == 'px': - self.paragraph.paragraph_format.left_indent = Inches(min(margin // 10 * INDENT, MAX_INDENT)) - # TODO handle non px units - - def add_styles_to_run(self, style): - if 'color' in style: - if 'rgb' in style['color']: - color = re.sub(r'[a-z()]+', '', style['color']) - colors = [int(x) for x in color.split(',')] - elif '#' in style['color']: - color = style['color'].lstrip('#') - colors = tuple(int(color[i:i+2], 16) for i in (0, 2, 4)) - else: - colors = [0, 0, 0] - # TODO map colors to named colors (and extended colors...) - # For now set color to black to prevent crashing - self.run.font.color.rgb = RGBColor(*colors) - - if 'background-color' in style: - if 'rgb' in style['background-color']: - color = color = re.sub(r'[a-z()]+', '', style['background-color']) - colors = [int(x) for x in color.split(',')] - elif '#' in style['background-color']: - color = style['background-color'].lstrip('#') - colors = tuple(int(color[i:i+2], 16) for i in (0, 2, 4)) - else: - colors = [0, 0, 0] - # TODO map colors to named colors (and extended colors...) 
- # For now set color to black to prevent crashing - self.run.font.highlight_color = WD_COLOR.GRAY_25 #TODO: map colors + def ignore_nested_tables(self, tables_soup): + """ + Returns array containing only the highest level tables + Operates on the assumption that bs4 returns child elements immediately after + the parent element in `find_all`. If this changes in the future, this method will need to be updated + :return: + """ + new_tables = [] + nest = 0 + for table in tables_soup: + if nest: + nest -= 1 + continue + new_tables.append(table) + nest = len(table.find_all('table')) + return new_tables + + def get_tables(self): + if not hasattr(self, 'soup'): + self.include_tables = False + return + # find other way to do it, or require this dependency? + self.tables = self.ignore_nested_tables(self.soup.find_all('table')) + self.table_no = 0 + + def run_process(self, html): + if self.bs and BeautifulSoup: + self.soup = BeautifulSoup(html, 'html.parser') + html = str(self.soup) + if self.include_tables: + self.get_tables() + self.feed(html) + + def add_html_to_cell(self, html, cell): + if not isinstance(cell, docx.table._Cell): + raise ValueError('Second argument needs to be a %s' % + docx.table._Cell) + unwanted_paragraph = cell.paragraphs[0] + if unwanted_paragraph.text == "": + delete_paragraph(unwanted_paragraph) + self.set_initial_attrs(cell) + self.run_process(html) + # cells must end with a paragraph or will get message about corrupt file + # https://stackoverflow.com/a/29287121 + if not self.doc.paragraphs: + self.doc.add_paragraph('') def apply_paragraph_style(self, style=None): try: @@ -250,69 +232,10 @@ class HtmlToDocx(HTMLParser): elif self.paragraph_style: self.paragraph.style = self.paragraph_style except KeyError as e: - raise ValueError(f"Unable to apply style {self.paragraph_style}.") from e - - def parse_dict_string(self, string, separator=';'): - new_string = string.replace(" ", '').split(separator) - string_dict = dict([x.split(':') for x in new_string if ':' in x]) - return string_dict - - def handle_li(self): - # check list stack to determine style and depth - list_depth = len(self.tags['list']) - if list_depth: - list_type = self.tags['list'][-1] - else: - list_type = 'ul' # assign unordered if no tag + raise ValueError( + f"Unable to apply style {self.paragraph_style}.") from e - if list_type == 'ol': - list_style = styles['LIST_NUMBER'] - else: - list_style = styles['LIST_BULLET'] - - self.paragraph = self.doc.add_paragraph(style=list_style) - self.paragraph.paragraph_format.left_indent = Inches(min(list_depth * LIST_INDENT, MAX_INDENT)) - self.paragraph.paragraph_format.line_spacing = 1 - - def add_image_to_cell(self, cell, image): - # python-docx doesn't have method yet for adding images to table cells. 
For now we use this - paragraph = cell.add_paragraph() - run = paragraph.add_run() - run.add_picture(image) - - def handle_img(self, current_attrs): - if not self.include_images: - self.skip = True - self.skip_tag = 'img' - return - src = current_attrs['src'] - # fetch image - src_is_url = is_url(src) - if src_is_url: - try: - image = fetch_image(src) - except urllib.error.URLError: - image = None - else: - image = src - # add image to doc - if image: - try: - if isinstance(self.doc, docx.document.Document): - self.doc.add_picture(image) - else: - self.add_image_to_cell(self.doc, image) - except FileNotFoundError: - image = None - if not image: - if src_is_url: - self.doc.add_paragraph("" % src) - else: - # avoid exposing filepaths in document - self.doc.add_paragraph("" % get_filename_from_url(src)) - - - def handle_table(self, html): + def handle_table(self, html, doc): """ To handle nested tables, we will parse tables manually as follows: Get table soup @@ -320,194 +243,42 @@ class HtmlToDocx(HTMLParser): Iterate over soup and fill docx table with new instances of this parser Tell HTMLParser to ignore any tags until the corresponding closing table tag """ - doc = Document() table_soup = BeautifulSoup(html, 'html.parser') - rows, cols_len = self.get_table_dimensions(table_soup) + rows, cols_len = get_table_dimensions(table_soup) table = doc.add_table(len(rows), cols_len) table.style = doc.styles['Table Grid'] + cell_row = 0 for index, row in enumerate(rows): - cols = self.get_table_columns(row) + cols = get_table_columns(row) cell_col = 0 for col in cols: colspan = int(col.attrs.get('colspan', 1)) rowspan = int(col.attrs.get('rowspan', 1)) - cell_html = self.get_cell_html(col) - + cell_html = get_cell_html(col) if col.name == 'th': cell_html = "%s" % cell_html + docx_cell = table.cell(cell_row, cell_col) + while docx_cell.text != '': # Skip the merged cell cell_col += 1 docx_cell = table.cell(cell_row, cell_col) - cell_to_merge = table.cell(cell_row + rowspan - 1, cell_col + colspan - 1) + cell_to_merge = table.cell(cell_row + rowspan - 1, + cell_col + colspan - 1) if docx_cell != cell_to_merge: docx_cell.merge(cell_to_merge) child_parser = HtmlToDocx() child_parser.copy_settings_from(self) - - child_parser.add_html_to_cell(cell_html or ' ', docx_cell) # occupy the position + child_parser.add_html_to_cell(cell_html or ' ', docx_cell) cell_col += colspan cell_row += 1 - - # skip all tags until corresponding closing tag - self.instances_to_skip = len(table_soup.find_all('table')) - self.skip_tag = 'table' - self.skip = True - self.table = None - return table - - def handle_link(self, href, text): - # Link requires a relationship - is_external = href.startswith('http') - rel_id = self.paragraph.part.relate_to( - href, - docx.opc.constants.RELATIONSHIP_TYPE.HYPERLINK, - is_external=True # don't support anchor links for this library yet - ) - - # Create the w:hyperlink tag and add needed values - hyperlink = docx.oxml.shared.OxmlElement('w:hyperlink') - hyperlink.set(docx.oxml.shared.qn('r:id'), rel_id) - - - # Create sub-run - subrun = self.paragraph.add_run() - rPr = docx.oxml.shared.OxmlElement('w:rPr') - - # add default color - c = docx.oxml.shared.OxmlElement('w:color') - c.set(docx.oxml.shared.qn('w:val'), "0000EE") - rPr.append(c) - - # add underline - u = docx.oxml.shared.OxmlElement('w:u') - u.set(docx.oxml.shared.qn('w:val'), 'single') - rPr.append(u) - - subrun._r.append(rPr) - subrun._r.text = text - - # Add subrun to hyperlink - hyperlink.append(subrun._r) - - # Add hyperlink to 
run - self.paragraph._p.append(hyperlink) - - def handle_starttag(self, tag, attrs): - if self.skip: - return - if tag == 'head': - self.skip = True - self.skip_tag = tag - self.instances_to_skip = 0 - return - elif tag == 'body': - return - - current_attrs = dict(attrs) - - if tag == 'span': - self.tags['span'].append(current_attrs) - return - elif tag == 'ol' or tag == 'ul': - self.tags['list'].append(tag) - return # don't apply styles for now - elif tag == 'br': - self.run.add_break() - return - - self.tags[tag] = current_attrs - if tag in ['p', 'pre']: - self.paragraph = self.doc.add_paragraph() - self.apply_paragraph_style() - - elif tag == 'li': - self.handle_li() - - elif tag == "hr": - - # This implementation was taken from: - # https://github.com/python-openxml/python-docx/issues/105#issuecomment-62806373 - - self.paragraph = self.doc.add_paragraph() - pPr = self.paragraph._p.get_or_add_pPr() - pBdr = OxmlElement('w:pBdr') - pPr.insert_element_before(pBdr, - 'w:shd', 'w:tabs', 'w:suppressAutoHyphens', 'w:kinsoku', 'w:wordWrap', - 'w:overflowPunct', 'w:topLinePunct', 'w:autoSpaceDE', 'w:autoSpaceDN', - 'w:bidi', 'w:adjustRightInd', 'w:snapToGrid', 'w:spacing', 'w:ind', - 'w:contextualSpacing', 'w:mirrorIndents', 'w:suppressOverlap', 'w:jc', - 'w:textDirection', 'w:textAlignment', 'w:textboxTightWrap', - 'w:outlineLvl', 'w:divId', 'w:cnfStyle', 'w:rPr', 'w:sectPr', - 'w:pPrChange' - ) - bottom = OxmlElement('w:bottom') - bottom.set(qn('w:val'), 'single') - bottom.set(qn('w:sz'), '6') - bottom.set(qn('w:space'), '1') - bottom.set(qn('w:color'), 'auto') - pBdr.append(bottom) - - elif re.match('h[1-9]', tag): - if isinstance(self.doc, docx.document.Document): - h_size = int(tag[1]) - self.paragraph = self.doc.add_heading(level=min(h_size, 9)) - else: - self.paragraph = self.doc.add_paragraph() - - elif tag == 'img': - self.handle_img(current_attrs) - return - - elif tag == 'table': - self.handle_table() - return - # set new run reference point in case of leading line breaks - if tag in ['p', 'li', 'pre']: - self.run = self.paragraph.add_run() - - # add style - if not self.include_styles: - return - if 'style' in current_attrs and self.paragraph: - style = self.parse_dict_string(current_attrs['style']) - self.add_styles_to_paragraph(style) - - def handle_endtag(self, tag): - if self.skip: - if not tag == self.skip_tag: - return - - if self.instances_to_skip > 0: - self.instances_to_skip -= 1 - return - - self.skip = False - self.skip_tag = None - self.paragraph = None - - if tag == 'span': - if self.tags['span']: - self.tags['span'].pop() - return - elif tag == 'ol' or tag == 'ul': - remove_last_occurence(self.tags['list'], tag) - return - elif tag == 'table': - self.table_no += 1 - self.table = None - self.doc = self.document - self.paragraph = None - - if tag in self.tags: - self.tags.pop(tag) - # maybe set relevant reference to None? + doc.save('1.docx') def handle_data(self, data): if self.skip: @@ -546,87 +317,3 @@ class HtmlToDocx(HTMLParser): if tag in font_names: font_name = font_names[tag] self.run.font.name = font_name - - def ignore_nested_tables(self, tables_soup): - """ - Returns array containing only the highest level tables - Operates on the assumption that bs4 returns child elements immediately after - the parent element in `find_all`. 
If this changes in the future, this method will need to be updated - :return: - """ - new_tables = [] - nest = 0 - for table in tables_soup: - if nest: - nest -= 1 - continue - new_tables.append(table) - nest = len(table.find_all('table')) - return new_tables - - def get_table_rows(self, table_soup): - # If there's a header, body, footer or direct child tr tags, add row dimensions from there - return table_soup.select(', '.join(self.table_row_selectors), recursive=False) - - def get_table_columns(self, row): - # Get all columns for the specified row tag. - return row.find_all(['th', 'td'], recursive=False) if row else [] - - def get_table_dimensions(self, table_soup): - # Get rows for the table - rows = self.get_table_rows(table_soup) - # Table is either empty or has non-direct children between table and tr tags - # Thus the row dimensions and column dimensions are assumed to be 0 - - cols = self.get_table_columns(rows[0]) if rows else [] - # Add colspan calculation column number - col_count = 0 - for col in cols: - colspan = col.attrs.get('colspan', 1) - col_count += int(colspan) - - # return len(rows), col_count - return rows, col_count - - def get_tables(self): - if not hasattr(self, 'soup'): - self.include_tables = False - return - # find other way to do it, or require this dependency? - self.tables = self.ignore_nested_tables(self.soup.find_all('table')) - self.table_no = 0 - - def run_process(self, html): - if self.bs and BeautifulSoup: - self.soup = BeautifulSoup(html, 'html.parser') - html = str(self.soup) - if self.include_tables: - self.get_tables() - self.feed(html) - - def add_html_to_document(self, html, document): - if not isinstance(html, str): - raise ValueError('First argument needs to be a %s' % str) - elif not isinstance(document, docx.document.Document) and not isinstance(document, docx.table._Cell): - raise ValueError('Second argument needs to be a %s' % docx.document.Document) - self.set_initial_attrs(document) - self.run_process(html) - - def add_html_to_cell(self, html, cell): - self.set_initial_attrs(cell) - self.run_process(html) - - def parse_html_file(self, filename_html, filename_docx=None): - with open(filename_html, 'r') as infile: - html = infile.read() - self.set_initial_attrs() - self.run_process(html) - if not filename_docx: - path, filename = os.path.split(filename_html) - filename_docx = '%s/new_docx_file_%s' % (path, filename) - self.doc.save('%s.docx' % filename_docx) - - def parse_html_string(self, html): - self.set_initial_attrs() - self.run_process(html) - return self.doc \ No newline at end of file diff --git a/ppstructure/table/README.md b/ppstructure/table/README.md index 08635516ba8301e6f98f175e5eba8c0a97b1708e..1d082f3878c56e42d175d13c75e1fe17916e7781 100644 --- a/ppstructure/table/README.md +++ b/ppstructure/table/README.md @@ -114,7 +114,7 @@ python3 table/eval_table.py \ --det_model_dir=path/to/det_model_dir \ --rec_model_dir=path/to/rec_model_dir \ --table_model_dir=path/to/table_model_dir \ - --image_dir=../doc/table/1.png \ + --image_dir=docs/table/table.jpg \ --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt \ --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt \ --det_limit_side_len=736 \ @@ -145,6 +145,7 @@ python3 table/eval_table.py \ --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt \ --det_limit_side_len=736 \ --det_limit_type=min \ + --rec_image_shape=3,32,320 \ --gt_path=path/to/gt.txt ``` diff --git a/ppstructure/table/README_ch.md b/ppstructure/table/README_ch.md index 
1ef126261d9ce832cd1919a1b3991f341add998c..feccb70adfe20fa8c1cd06f33a10ee6fa043e69e 100644 --- a/ppstructure/table/README_ch.md +++ b/ppstructure/table/README_ch.md @@ -118,7 +118,7 @@ python3 table/eval_table.py \ --det_model_dir=path/to/det_model_dir \ --rec_model_dir=path/to/rec_model_dir \ --table_model_dir=path/to/table_model_dir \ - --image_dir=../doc/table/1.png \ + --image_dir=docs/table/table.jpg \ --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt \ --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt \ --det_limit_side_len=736 \ @@ -149,6 +149,7 @@ python3 table/eval_table.py \ --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt \ --det_limit_side_len=736 \ --det_limit_type=min \ + --rec_image_shape=3,32,320 \ --gt_path=path/to/gt.txt ``` diff --git a/ppstructure/table/predict_structure.py b/ppstructure/table/predict_structure.py index 45cbba3e298004d3711b05e6fb7cffecae637601..0bf100852b9e9d501dfc858d8ce0787da42a61ed 100755 --- a/ppstructure/table/predict_structure.py +++ b/ppstructure/table/predict_structure.py @@ -68,6 +68,7 @@ def build_pre_process_list(args): class TableStructurer(object): def __init__(self, args): + self.use_onnx = args.use_onnx pre_process_list = build_pre_process_list(args) if args.table_algorithm not in ['TableMaster']: postprocess_params = { @@ -98,13 +99,17 @@ class TableStructurer(object): return None, 0 img = np.expand_dims(img, axis=0) img = img.copy() - - self.input_tensor.copy_from_cpu(img) - self.predictor.run() - outputs = [] - for output_tensor in self.output_tensors: - output = output_tensor.copy_to_cpu() - outputs.append(output) + if self.use_onnx: + input_dict = {} + input_dict[self.input_tensor.name] = img + outputs = self.predictor.run(self.output_tensors, input_dict) + else: + self.input_tensor.copy_from_cpu(img) + self.predictor.run() + outputs = [] + for output_tensor in self.output_tensors: + output = output_tensor.copy_to_cpu() + outputs.append(output) preds = {} preds['structure_probs'] = outputs[1] diff --git a/ppstructure/utility.py b/ppstructure/utility.py index bdea0af69e37e15d1f191b2a86c036ae1c2b1e45..59b58edb4b0c9c5992981073b12e419fe1cc84d6 100644 --- a/ppstructure/utility.py +++ b/ppstructure/utility.py @@ -32,7 +32,7 @@ def init_args(): parser.add_argument( "--table_char_dict_path", type=str, - default="../ppocr/utils/dict/table_structure_dict.txt") + default="../ppocr/utils/dict/table_structure_dict_ch.txt") # params for layout parser.add_argument("--layout_model_dir", type=str) parser.add_argument( @@ -52,6 +52,8 @@ def init_args(): # params for kie parser.add_argument("--kie_algorithm", type=str, default='LayoutXLM') parser.add_argument("--ser_model_dir", type=str) + parser.add_argument("--re_model_dir", type=str) + parser.add_argument("--use_visual_backbone", type=str2bool, default=True) parser.add_argument( "--ser_dict_path", type=str, @@ -90,11 +92,6 @@ def init_args(): type=str2bool, default=False, help='Whether to enable layout of recovery') - parser.add_argument( - "--save_pdf", - type=str2bool, - default=False, - help='Whether to save pdf file') return parser @@ -108,7 +105,38 @@ def draw_structure_result(image, result, font_path): if isinstance(image, np.ndarray): image = Image.fromarray(image) boxes, txts, scores = [], [], [] + + img_layout = image.copy() + draw_layout = ImageDraw.Draw(img_layout) + text_color = (255, 255, 255) + text_background_color = (80, 127, 255) + catid2color = {} + font_size = 15 + font = ImageFont.truetype(font_path, font_size, encoding="utf-8") + for 
region in result: + if region['type'] not in catid2color: + box_color = (random.randint(0, 255), random.randint(0, 255), + random.randint(0, 255)) + catid2color[region['type']] = box_color + else: + box_color = catid2color[region['type']] + box_layout = region['bbox'] + draw_layout.rectangle( + [(box_layout[0], box_layout[1]), (box_layout[2], box_layout[3])], + outline=box_color, + width=3) + text_w, text_h = font.getsize(region['type']) + draw_layout.rectangle( + [(box_layout[0], box_layout[1]), + (box_layout[0] + text_w, box_layout[1] + text_h)], + fill=text_background_color) + draw_layout.text( + (box_layout[0], box_layout[1]), + region['type'], + fill=text_color, + font=font) + if region['type'] == 'table': pass else: @@ -116,6 +144,7 @@ def draw_structure_result(image, result, font_path): boxes.append(np.array(text_result['text_region'])) txts.append(text_result['text']) scores.append(text_result['confidence']) + im_show = draw_ocr_box_txt( - image, boxes, txts, scores, font_path=font_path, drop_score=0) + img_layout, boxes, txts, scores, font_path=font_path, drop_score=0) return im_show diff --git a/requirements.txt b/requirements.txt index 2c0741a065dacf1fb637865e8f9796a611876d60..7a018b50952a876b4839eabbd72fac09d2bbd73b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,3 +14,5 @@ lxml premailer openpyxl attrdict +Polygon3 +PyMuPDF==1.18.7 diff --git a/test_tipc/benchmark_train.sh b/test_tipc/benchmark_train.sh index 1dcb0129e767e6c35adfad36aa5dce2fbd84a2fd..25fda8f97f0bfdefbd6922b13a0ffef3f40c3de9 100644 --- a/test_tipc/benchmark_train.sh +++ b/test_tipc/benchmark_train.sh @@ -1,12 +1,6 @@ #!/bin/bash source test_tipc/common_func.sh -# set env -python=python -export str_tmp=$(echo `pip list|grep paddlepaddle-gpu|awk -F ' ' '{print $2}'`) -export frame_version=${str_tmp%%.post*} -export frame_commit=$(echo `${python} -c "import paddle;print(paddle.version.commit)"`) - # run benchmark sh # Usage: # bash run_benchmark_train.sh config.txt params @@ -86,6 +80,13 @@ dataline=`cat $FILENAME` IFS=$'\n' lines=(${dataline}) model_name=$(func_parser_value "${lines[1]}") +python_name=$(func_parser_value "${lines[2]}") + +# set env +python=${python_name} +export str_tmp=$(echo `pip list|grep paddlepaddle-gpu|awk -F ' ' '{print $2}'`) +export frame_version=${str_tmp%%.post*} +export frame_commit=$(echo `${python} -c "import paddle;print(paddle.version.commit)"`) # 获取benchmark_params所在的行数 line_num=`grep -n "train_benchmark_params" $FILENAME | cut -d ":" -f 1` diff --git a/test_tipc/configs/ch_ppocr_mobile_v2_0_det/train_infer_python.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_det/train_infer_python.txt index f3aa9d0f8218a24b11e3d0d079ae79a07d3e5874..4112e6498c6316e211ad69a69bdb531ec7a105b2 100644 --- a/test_tipc/configs/ch_ppocr_mobile_v2_0_det/train_infer_python.txt +++ b/test_tipc/configs/ch_ppocr_mobile_v2_0_det/train_infer_python.txt @@ -13,7 +13,7 @@ train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/ null:null ## trainer:norm_train -norm_train:tools/train.py -c configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml -o Global.pretrained_model=./pretrain_models/MobileNetV3_large_x0_5_pretrained Global.print_batch_step=1 Train.loader.shuffle=false +norm_train:tools/train.py -c configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml -o Global.pretrained_model=./pretrain_models/MobileNetV3_large_x0_5_pretrained Global.print_batch_step=2 Train.loader.shuffle=false pact_train:null fpgm_train:null distill_train:null diff --git 
a/test_tipc/configs/det_r18_ct/train_infer_python.txt b/test_tipc/configs/det_r18_ct/train_infer_python.txt new file mode 100644 index 0000000000000000000000000000000000000000..5933fdbeed762a73324fbfb5a4113a390926e7ea --- /dev/null +++ b/test_tipc/configs/det_r18_ct/train_infer_python.txt @@ -0,0 +1,53 @@ +===========================train_params=========================== +model_name:det_r18_ct +python:python3.7 +gpu_list:0|0,1 +Global.use_gpu:True|True +Global.auto_cast:null +Global.epoch_num:lite_train_lite_infer=2|whole_train_whole_infer=300 +Global.save_model_dir:./output/ +Train.loader.batch_size_per_card:lite_train_lite_infer=2|whole_train_lite_infer=4 +Global.pretrained_model:null +train_model_name:latest +train_infer_img_dir:./train_data/total_text/test/rgb/ +null:null +## +trainer:norm_train +norm_train:tools/train.py -c configs/det/det_r18_vd_ct.yml -o Global.print_batch_step=1 Train.loader.shuffle=false +quant_export:null +fpgm_export:null +distill_train:null +null:null +null:null +## +===========================eval_params=========================== +eval:tools/eval.py -c configs/det/det_r18_vd_ct.yml -o +null:null +## +===========================infer_params=========================== +Global.save_inference_dir:./output/ +Global.checkpoints: +norm_export:tools/export_model.py -c configs/det/det_r18_vd_ct.yml -o +quant_export:null +fpgm_export:null +distill_export:null +export1:null +export2:null +## +train_model:./inference/det_r18_vd_ct/best_accuracy +infer_export:tools/export_model.py -c configs/det/det_r18_vd_ct.yml -o +infer_quant:False +inference:tools/infer/predict_det.py +--use_gpu:True|False +--enable_mkldnn:False +--cpu_threads:6 +--rec_batch_num:1 +--use_tensorrt:False +--precision:fp32 +--det_model_dir: +--image_dir:./inference/ch_det_data_50/all-sum-510/ +--save_log_path:null +--benchmark:True +null:null +===========================infer_benchmark_params========================== +random_infer_input:[{float32,[3,640,640]}];[{float32,[3,960,960]}] \ No newline at end of file diff --git a/test_tipc/configs/en_table_structure/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt b/test_tipc/configs/en_table_structure/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt new file mode 100644 index 0000000000000000000000000000000000000000..ad002a334e3b351b0fa2aa641906f4aa753071c9 --- /dev/null +++ b/test_tipc/configs/en_table_structure/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt @@ -0,0 +1,20 @@ +===========================cpp_infer_params=========================== +model_name:en_table_structure +use_opencv:True +infer_model:./inference/en_ppocr_mobile_v2.0_table_structure_infer/ +infer_quant:False +inference:./deploy/cpp_infer/build/ppocr --rec_img_h=32 --det_model_dir=./inference/en_ppocr_mobile_v2.0_table_det_infer --rec_model_dir=./inference/en_ppocr_mobile_v2.0_table_rec_infer --rec_char_dict_path=./ppocr/utils/dict/table_dict.txt --table_char_dict_path=./ppocr/utils/dict/table_structure_dict.txt --limit_side_len=736 --limit_type=min --output=./output/table --merge_no_span_structure=False --type=structure --table=True +--use_gpu:True|False +--enable_mkldnn:False +--cpu_threads:6 +--rec_batch_num:6 +--use_tensorrt:False +--precision:fp32 +--table_model_dir: +--image_dir:./ppstructure/docs/table/table.jpg +null:null +--benchmark:True +--det:True +--rec:True +--cls:False +--use_angle_cls:False \ No newline at end of file diff --git a/test_tipc/configs/en_table_structure/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt 
b/test_tipc/configs/en_table_structure/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt new file mode 100644 index 0000000000000000000000000000000000000000..068c4c6b1d2655b9dcda1120425de7d52d0d543d --- /dev/null +++ b/test_tipc/configs/en_table_structure/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt @@ -0,0 +1,17 @@ +===========================paddle2onnx_params=========================== +model_name:en_table_structure +python:python3.7 +2onnx: paddle2onnx +--det_model_dir:./inference/en_ppocr_mobile_v2.0_table_structure_infer/ +--model_filename:inference.pdmodel +--params_filename:inference.pdiparams +--det_save_file:./inference/en_ppocr_mobile_v2.0_table_structure_infer/model.onnx +--rec_model_dir: +--rec_save_file: +--opset_version:10 +--enable_onnx_checker:True +inference:ppstructure/table/predict_structure.py --table_char_dict_path=./ppocr/utils/dict/table_structure_dict.txt +--use_gpu:True|False +--det_model_dir: +--rec_model_dir: +--image_dir:./ppstructure/docs/table/table.jpg \ No newline at end of file diff --git a/test_tipc/configs/en_table_structure_PACT/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt b/test_tipc/configs/en_table_structure_PACT/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt new file mode 100644 index 0000000000000000000000000000000000000000..22f77c469bcf7faeb40f5018116e29f67feeadf2 --- /dev/null +++ b/test_tipc/configs/en_table_structure_PACT/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt @@ -0,0 +1,20 @@ +===========================cpp_infer_params=========================== +model_name:en_table_structure_PACT +use_opencv:True +infer_model:./inference/en_ppocr_mobile_v2.0_table_structure_slim_infer/ +infer_quant:False +inference:./deploy/cpp_infer/build/ppocr --rec_img_h=32 --det_model_dir=./inference/en_ppocr_mobile_v2.0_table_det_infer --rec_model_dir=./inference/en_ppocr_mobile_v2.0_table_rec_infer --rec_char_dict_path=./ppocr/utils/dict/table_dict.txt --table_char_dict_path=./ppocr/utils/dict/table_structure_dict.txt --limit_side_len=736 --limit_type=min --output=./output/table --merge_no_span_structure=False --type=structure --table=True +--use_gpu:True|False +--enable_mkldnn:False +--cpu_threads:6 +--rec_batch_num:6 +--use_tensorrt:False +--precision:fp32 +--table_model_dir: +--image_dir:./ppstructure/docs/table/table.jpg +null:null +--benchmark:True +--det:True +--rec:True +--cls:False +--use_angle_cls:False \ No newline at end of file diff --git a/test_tipc/configs/layoutxlm_ser/ser_layoutxlm_xfund_zh.yml b/test_tipc/configs/layoutxlm_ser/ser_layoutxlm_xfund_zh.yml new file mode 100644 index 0000000000000000000000000000000000000000..d2be152f0bae7d87129904d87c56c6d777a1f338 --- /dev/null +++ b/test_tipc/configs/layoutxlm_ser/ser_layoutxlm_xfund_zh.yml @@ -0,0 +1,122 @@ +Global: + use_gpu: True + epoch_num: &epoch_num 200 + log_smooth_window: 10 + print_batch_step: 10 + save_model_dir: ./output/ser_layoutxlm_xfund_zh + save_epoch_step: 2000 + # evaluation is run every 10 iterations after the 0th iteration + eval_batch_step: [ 0, 187 ] + cal_metric_during_train: False + save_inference_dir: + use_visualdl: False + seed: 2022 + infer_img: ppstructure/docs/kie/input/zh_val_42.jpg + save_res_path: ./output/ser_layoutxlm_xfund_zh/res + +Architecture: + model_type: kie + algorithm: &algorithm "LayoutXLM" + Transform: + Backbone: + name: LayoutXLMForSer + pretrained: True + checkpoints: + num_classes: &num_classes 7 + +Loss: + name: VQASerTokenLayoutLMLoss + num_classes: *num_classes + key: 
"backbone_out" + +Optimizer: + name: AdamW + beta1: 0.9 + beta2: 0.999 + lr: + name: Linear + learning_rate: 0.00005 + epochs: *epoch_num + warmup_epoch: 2 + regularizer: + name: L2 + factor: 0.00000 + +PostProcess: + name: VQASerTokenLayoutLMPostProcess + class_path: &class_path train_data/XFUND/class_list_xfun.txt + +Metric: + name: VQASerTokenMetric + main_indicator: hmean + +Train: + dataset: + name: SimpleDataSet + data_dir: train_data/XFUND/zh_train/image + label_file_list: + - train_data/XFUND/zh_train/train.json + ratio_list: [ 1.0 ] + transforms: + - DecodeImage: # load image + img_mode: RGB + channel_first: False + - VQATokenLabelEncode: # Class handling label + contains_re: False + algorithm: *algorithm + class_path: *class_path + - VQATokenPad: + max_seq_len: &max_seq_len 512 + return_attention_mask: True + - VQASerTokenChunk: + max_seq_len: *max_seq_len + - Resize: + size: [224,224] + - NormalizeImage: + scale: 1 + mean: [ 123.675, 116.28, 103.53 ] + std: [ 58.395, 57.12, 57.375 ] + order: 'hwc' + - ToCHWImage: + - KeepKeys: + keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels'] # dataloader will return list in this order + loader: + shuffle: True + drop_last: False + batch_size_per_card: 8 + num_workers: 4 + +Eval: + dataset: + name: SimpleDataSet + data_dir: train_data/XFUND/zh_val/image + label_file_list: + - train_data/XFUND/zh_val/val.json + transforms: + - DecodeImage: # load image + img_mode: RGB + channel_first: False + - VQATokenLabelEncode: # Class handling label + contains_re: False + algorithm: *algorithm + class_path: *class_path + - VQATokenPad: + max_seq_len: *max_seq_len + return_attention_mask: True + - VQASerTokenChunk: + max_seq_len: *max_seq_len + - Resize: + size: [224,224] + - NormalizeImage: + scale: 1 + mean: [ 123.675, 116.28, 103.53 ] + std: [ 58.395, 57.12, 57.375 ] + order: 'hwc' + - ToCHWImage: + - KeepKeys: + keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels'] # dataloader will return list in this order + loader: + shuffle: False + drop_last: False + batch_size_per_card: 8 + num_workers: 4 diff --git a/test_tipc/configs/layoutxlm_ser/train_infer_python.txt b/test_tipc/configs/layoutxlm_ser/train_infer_python.txt index 549a31e69e367237ec0396778162a5f91c8b7412..d07daa9a1429ec5cd1955ec64ded122a9d1a723d 100644 --- a/test_tipc/configs/layoutxlm_ser/train_infer_python.txt +++ b/test_tipc/configs/layoutxlm_ser/train_infer_python.txt @@ -13,7 +13,7 @@ train_infer_img_dir:ppstructure/docs/kie/input/zh_val_42.jpg null:null ## trainer:norm_train -norm_train:tools/train.py -c configs/kie/layoutlm_series/ser_layoutlm_xfund_zh.yml -o Global.print_batch_step=1 Global.eval_batch_step=[1000,1000] Train.loader.shuffle=false +norm_train:tools/train.py -c test_tipc/configs/layoutxlm_ser/ser_layoutxlm_xfund_zh.yml -o Global.print_batch_step=1 Global.eval_batch_step=[1000,1000] Train.loader.shuffle=false pact_train:null fpgm_train:null distill_train:null @@ -27,7 +27,7 @@ null:null ===========================infer_params=========================== Global.save_inference_dir:./output/ Architecture.Backbone.checkpoints: -norm_export:tools/export_model.py -c configs/kie/layoutlm_series/ser_layoutlm_xfund_zh.yml -o +norm_export:tools/export_model.py -c test_tipc/configs/layoutxlm_ser/ser_layoutxlm_xfund_zh.yml -o quant_export: fpgm_export: distill_export:null diff --git a/test_tipc/configs/layoutxlm_ser/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt 
b/test_tipc/configs/layoutxlm_ser/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt new file mode 100644 index 0000000000000000000000000000000000000000..96b43ceda84376f4d27e134d245229922d667e7e --- /dev/null +++ b/test_tipc/configs/layoutxlm_ser/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt @@ -0,0 +1,53 @@ +===========================train_params=========================== +model_name:layoutxlm_ser +python:python3.7 +gpu_list:192.168.0.1,192.168.0.2;0,1 +Global.use_gpu:True +Global.auto_cast:fp32 +Global.epoch_num:lite_train_lite_infer=1|whole_train_whole_infer=17 +Global.save_model_dir:./output/ +Train.loader.batch_size_per_card:lite_train_lite_infer=4|whole_train_whole_infer=8 +Architecture.Backbone.checkpoints:null +train_model_name:latest +train_infer_img_dir:ppstructure/docs/kie/input/zh_val_42.jpg +null:null +## +trainer:norm_train +norm_train:tools/train.py -c test_tipc/configs/layoutxlm_ser/ser_layoutxlm_xfund_zh.yml -o +pact_train:null +fpgm_train:null +distill_train:null +null:null +null:null +## +===========================eval_params=========================== +eval:null +null:null +## +===========================infer_params=========================== +Global.save_inference_dir:./output/ +Architecture.Backbone.checkpoints: +norm_export:tools/export_model.py -c test_tipc/configs/layoutxlm_ser/ser_layoutxlm_xfund_zh.yml -o +quant_export: +fpgm_export: +distill_export:null +export1:null +export2:null +## +infer_model:null +infer_export:null +infer_quant:False +inference:ppstructure/kie/predict_kie_token_ser.py --kie_algorithm=LayoutXLM --ser_dict_path=train_data/XFUND/class_list_xfun.txt --output=output +--use_gpu:False +--enable_mkldnn:False +--cpu_threads:6 +--rec_batch_num:1 +--use_tensorrt:False +--precision:fp32 +--ser_model_dir: +--image_dir:./ppstructure/docs/kie/input/zh_val_42.jpg +null:null +--benchmark:False +null:null +===========================infer_benchmark_params========================== +random_infer_input:[{float32,[3,224,224]}] diff --git a/test_tipc/configs/layoutxlm_ser/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt b/test_tipc/configs/layoutxlm_ser/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt new file mode 100644 index 0000000000000000000000000000000000000000..835395784022f4fcefbfff084dcef8a7bc2a146d --- /dev/null +++ b/test_tipc/configs/layoutxlm_ser/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt @@ -0,0 +1,53 @@ +===========================train_params=========================== +model_name:layoutxlm_ser +python:python3.7 +gpu_list:0|0,1 +Global.use_gpu:True|True +Global.auto_cast:amp +Global.epoch_num:lite_train_lite_infer=1|whole_train_whole_infer=17 +Global.save_model_dir:./output/ +Train.loader.batch_size_per_card:lite_train_lite_infer=4|whole_train_whole_infer=8 +Architecture.Backbone.checkpoints:null +train_model_name:latest +train_infer_img_dir:ppstructure/docs/kie/input/zh_val_42.jpg +null:null +## +trainer:norm_train +norm_train:tools/train.py -c test_tipc/configs/layoutxlm_ser/ser_layoutxlm_xfund_zh.yml -o Global.print_batch_step=1 Global.eval_batch_step=[1000,1000] Train.loader.shuffle=false +pact_train:null +fpgm_train:null +distill_train:null +null:null +null:null +## +===========================eval_params=========================== +eval:null +null:null +## +===========================infer_params=========================== +Global.save_inference_dir:./output/ +Architecture.Backbone.checkpoints: +norm_export:tools/export_model.py -c 
test_tipc/configs/layoutxlm_ser/ser_layoutxlm_xfund_zh.yml -o +quant_export: +fpgm_export: +distill_export:null +export1:null +export2:null +## +infer_model:null +infer_export:null +infer_quant:False +inference:ppstructure/kie/predict_kie_token_ser.py --kie_algorithm=LayoutXLM --ser_dict_path=train_data/XFUND/class_list_xfun.txt --output=output +--use_gpu:True|False +--enable_mkldnn:False +--cpu_threads:6 +--rec_batch_num:1 +--use_tensorrt:False +--precision:fp32 +--ser_model_dir: +--image_dir:./ppstructure/docs/kie/input/zh_val_42.jpg +null:null +--benchmark:False +null:null +===========================infer_benchmark_params========================== +random_infer_input:[{float32,[3,224,224]}] diff --git a/test_tipc/configs/layoutxlm_ser/train_pact_infer_python.txt b/test_tipc/configs/layoutxlm_ser/train_pact_infer_python.txt new file mode 100644 index 0000000000000000000000000000000000000000..fbf2a880269fba4596908def0980cb778a9281e3 --- /dev/null +++ b/test_tipc/configs/layoutxlm_ser/train_pact_infer_python.txt @@ -0,0 +1,53 @@ +===========================train_params=========================== +model_name:layoutxlm_ser_PACT +python:python3.7 +gpu_list:0|0,1 +Global.use_gpu:True|True +Global.auto_cast:fp32 +Global.epoch_num:lite_train_lite_infer=1|whole_train_whole_infer=17 +Global.save_model_dir:./output/ +Train.loader.batch_size_per_card:lite_train_lite_infer=4|whole_train_whole_infer=8 +Architecture.Backbone.checkpoints:pretrain_models/ser_LayoutXLM_xfun_zh +train_model_name:latest +train_infer_img_dir:ppstructure/docs/kie/input/zh_val_42.jpg +null:null +## +trainer:pact_train +norm_train:null +pact_train:deploy/slim/quantization/quant.py -c test_tipc/configs/layoutxlm_ser/ser_layoutxlm_xfund_zh.yml -o +fpgm_train:null +distill_train:null +null:null +null:null +## +===========================eval_params=========================== +eval:null +null:null +## +===========================infer_params=========================== +Global.save_inference_dir:./output/ +Architecture.Backbone.checkpoints: +norm_export:null +quant_export:deploy/slim/quantization/export_model.py -c test_tipc/configs/layoutxlm_ser/ser_layoutxlm_xfund_zh.yml -o +fpgm_export: null +distill_export:null +export1:null +export2:null +## +infer_model:null +infer_export:null +infer_quant:False +inference:ppstructure/kie/predict_kie_token_ser.py --kie_algorithm=LayoutXLM --ser_dict_path=train_data/XFUND/class_list_xfun.txt --output=output +--use_gpu:True|False +--enable_mkldnn:False +--cpu_threads:6 +--rec_batch_num:1 +--use_tensorrt:False +--precision:fp32 +--ser_model_dir: +--image_dir:./ppstructure/docs/kie/input/zh_val_42.jpg +null:null +--benchmark:False +null:null +===========================infer_benchmark_params========================== +random_infer_input:[{float32,[3,224,224]}] diff --git a/test_tipc/configs/layoutxlm_ser/train_ptq_infer_python.txt b/test_tipc/configs/layoutxlm_ser/train_ptq_infer_python.txt new file mode 100644 index 0000000000000000000000000000000000000000..47e1e7026bd6bb113b05d70c2bfc7f90879bd485 --- /dev/null +++ b/test_tipc/configs/layoutxlm_ser/train_ptq_infer_python.txt @@ -0,0 +1,21 @@ +===========================train_params=========================== +model_name:layoutxlm_ser_KL +python:python3.7 +Global.pretrained_model: +Global.save_inference_dir:null +infer_model:./inference/ser_LayoutXLM_xfun_zh_infer/ +infer_export:deploy/slim/quantization/quant_kl.py -c test_tipc/configs/layoutxlm_ser/ser_layoutxlm_xfund_zh.yml -o Train.loader.batch_size_per_card=1 Eval.loader.batch_size_per_card=1 
+infer_quant:True +inference:ppstructure/kie/predict_kie_token_ser.py --kie_algorithm=LayoutXLM --ser_dict_path=./train_data/XFUND/class_list_xfun.txt +--use_gpu:True|False +--enable_mkldnn:False +--cpu_threads:6 +--rec_batch_num:1 +--use_tensorrt:False +--precision:int8 +--ser_model_dir: +--image_dir:./ppstructure/docs/kie/input/zh_val_42.jpg +null:null +--benchmark:False +null:null +null:null diff --git a/test_tipc/configs/slanet/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt b/test_tipc/configs/slanet/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt new file mode 100644 index 0000000000000000000000000000000000000000..1b4226706b067f65361fd3e79bcbc52e1cf70ad0 --- /dev/null +++ b/test_tipc/configs/slanet/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt @@ -0,0 +1,20 @@ +===========================cpp_infer_params=========================== +model_name:slanet +use_opencv:True +infer_model:./inference/ch_ppstructure_mobile_v2.0_SLANet_infer/ +infer_quant:False +inference:./deploy/cpp_infer/build/ppocr --det_model_dir=./inference/ch_PP-OCRv3_det_infer --rec_model_dir=./inference/ch_PP-OCRv3_rec_infer --output=./output/table --type=structure --table=True --rec_char_dict_path=./ppocr/utils/ppocr_keys_v1.txt --table_char_dict_path=./ppocr/utils/dict/table_structure_dict_ch.txt +--use_gpu:True|False +--enable_mkldnn:False +--cpu_threads:6 +--rec_batch_num:6 +--use_tensorrt:False +--precision:fp32 +--table_model_dir: +--image_dir:./ppstructure/docs/table/table.jpg +null:null +--benchmark:True +--det:True +--rec:True +--cls:False +--use_angle_cls:False \ No newline at end of file diff --git a/test_tipc/configs/slanet/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt b/test_tipc/configs/slanet/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt new file mode 100644 index 0000000000000000000000000000000000000000..45e4e9e858914dd8596cef10625df8160afe45fb --- /dev/null +++ b/test_tipc/configs/slanet/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt @@ -0,0 +1,17 @@ +===========================paddle2onnx_params=========================== +model_name:slanet +python:python3.7 +2onnx: paddle2onnx +--det_model_dir:./inference/ch_ppstructure_mobile_v2.0_SLANet_infer/ +--model_filename:inference.pdmodel +--params_filename:inference.pdiparams +--det_save_file:./inference/ch_ppstructure_mobile_v2.0_SLANet_infer/model.onnx +--rec_model_dir: +--rec_save_file: +--opset_version:10 +--enable_onnx_checker:True +inference:ppstructure/table/predict_structure.py --table_char_dict_path=./ppocr/utils/dict/table_structure_dict_ch.txt +--use_gpu:True|False +--det_model_dir: +--rec_model_dir: +--image_dir:./ppstructure/docs/table/table.jpg \ No newline at end of file diff --git a/test_tipc/configs/slanet/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt b/test_tipc/configs/slanet/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt new file mode 100644 index 0000000000000000000000000000000000000000..4c9d8d654ad286264e8511c788e278cdbcd52ec9 --- /dev/null +++ b/test_tipc/configs/slanet/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt @@ -0,0 +1,53 @@ +===========================train_params=========================== +model_name:slanet +python:python3.7 +gpu_list:192.168.0.1,192.168.0.2;0,1 +Global.use_gpu:True +Global.auto_cast:fp32 +Global.epoch_num:lite_train_lite_infer=3|whole_train_whole_infer=50 +Global.save_model_dir:./output/ +Train.loader.batch_size_per_card:lite_train_lite_infer=16|whole_train_whole_infer=128 
+Global.pretrained_model:./pretrain_models/en_ppstructure_mobile_v2.0_SLANet_train/best_accuracy +train_model_name:latest +train_infer_img_dir:./ppstructure/docs/table/table.jpg +null:null +## +trainer:norm_train +norm_train:tools/train.py -c test_tipc/configs/slanet/SLANet.yml -o +pact_train:null +fpgm_train:null +distill_train:null +null:null +null:null +## +===========================eval_params=========================== +eval:null +null:null +## +===========================infer_params=========================== +Global.save_inference_dir:./output/ +Global.checkpoints: +norm_export:tools/export_model.py -c test_tipc/configs/slanet/SLANet.yml -o +quant_export: +fpgm_export: +distill_export:null +export1:null +export2:null +## +infer_model:./inference/en_ppstructure_mobile_v2.0_SLANet_train +infer_export:null +infer_quant:False +inference:ppstructure/table/predict_table.py --det_model_dir=./inference/en_ppocr_mobile_v2.0_table_det_infer --rec_model_dir=./inference/en_ppocr_mobile_v2.0_table_rec_infer --rec_char_dict_path=./ppocr/utils/dict/table_dict.txt --table_char_dict_path=./ppocr/utils/dict/table_structure_dict.txt --image_dir=./ppstructure/docs/table/table.jpg --det_limit_side_len=736 --det_limit_type=min --output ./output/table +--use_gpu:False +--enable_mkldnn:False +--cpu_threads:6 +--rec_batch_num:1 +--use_tensorrt:False +--precision:fp32 +--table_model_dir: +--image_dir:./ppstructure/docs/table/table.jpg +null:null +--benchmark:False +null:null +===========================infer_benchmark_params========================== +random_infer_input:[{float32,[3,488,488]}] diff --git a/test_tipc/configs/table_master/train_infer_python.txt b/test_tipc/configs/table_master/train_infer_python.txt index 56b8e636026939ae8cd700308690010e1300d8f6..c3a871731a36fb5434db111cfd68b6eab7ba3f99 100644 --- a/test_tipc/configs/table_master/train_infer_python.txt +++ b/test_tipc/configs/table_master/train_infer_python.txt @@ -37,8 +37,8 @@ export2:null infer_model:null infer_export:null infer_quant:False -inference:ppstructure/table/predict_structure.py --table_char_dict_path=./ppocr/utils/dict/table_master_structure_dict.txt --image_dir=./ppstructure/docs/table/table.jpg --output ./output/table --table_algorithm=TableMaster --table_max_len=480 ---use_gpu:True|False +inference:ppstructure/table/predict_structure.py --table_char_dict_path=./ppocr/utils/dict/table_master_structure_dict.txt --output ./output/table --table_algorithm=TableMaster --table_max_len=480 +--use_gpu:True --enable_mkldnn:False --cpu_threads:6 --rec_batch_num:1 diff --git a/test_tipc/docs/jeston_test_train_inference_python.md b/test_tipc/docs/jeston_test_train_inference_python.md index b25175ed0071dd3728ae22c7588ca20535af0505..22fc21c1cb615fa3e9cb0eb12441db80968a23ed 100644 --- a/test_tipc/docs/jeston_test_train_inference_python.md +++ b/test_tipc/docs/jeston_test_train_inference_python.md @@ -24,12 +24,7 @@ Jetson端基础训练预测功能测试的主程序为`test_inference_inference. 
``` - 安装autolog(规范化日志输出工具) ``` - git clone https://github.com/LDOUBLEV/AutoLog - cd AutoLog - pip install -r requirements.txt - python setup.py bdist_wheel - pip install ./dist/auto_log-1.0.0-py3-none-any.whl - cd ../ + pip install https://paddleocr.bj.bcebos.com/libs/auto_log-1.2.0-py3-none-any.whl ``` - 安装PaddleSlim (可选) ``` diff --git a/test_tipc/docs/mac_test_train_inference_python.md b/test_tipc/docs/mac_test_train_inference_python.md index c37291a8fc9b239564adce8f556565f51f2a9475..759ea516430183a1b949ed5b69e24cceac8b6125 100644 --- a/test_tipc/docs/mac_test_train_inference_python.md +++ b/test_tipc/docs/mac_test_train_inference_python.md @@ -1,6 +1,6 @@ # Mac端基础训练预测功能测试 -Mac端基础训练预测功能测试的主程序为`test_train_inference_python.sh`,可以测试基于Python的模型CPU训练,包括裁剪、量化、蒸馏训练,以及评估、CPU推理等基本功能。 +Mac端基础训练预测功能测试的主程序为`test_train_inference_python.sh`,可以测试基于Python的模型CPU训练,包括裁剪、PACT在线量化、蒸馏训练,以及评估、CPU推理等基本功能。 注:Mac端测试用法同linux端测试方法类似,但是无需测试需要在GPU上运行的测试。 @@ -10,7 +10,7 @@ Mac端基础训练预测功能测试的主程序为`test_train_inference_python. | 算法名称 | 模型名称 | 单机单卡(CPU) | 单机多卡 | 多机多卡 | 模型压缩(CPU) | | :---- | :---- | :---- | :---- | :---- | :---- | -| DB | ch_ppocr_mobile_v2.0_det| 正常训练 | - | - | 正常训练:FPGM裁剪、PACT量化
离线量化(无需训练) | +| DB | ch_ppocr_mobile_v2.0_det| 正常训练 | - | - | 正常训练:FPGM裁剪、PACT量化 | - 预测相关:基于训练是否使用量化,可以将训练产出的模型可以分为`正常模型`和`量化模型`,这两类模型对应的预测功能汇总如下, @@ -26,19 +26,14 @@ Mac端基础训练预测功能测试的主程序为`test_train_inference_python. Mac端无GPU,环境准备只需要Python环境即可,安装PaddlePaddle等依赖参考下述文档。 ### 2.1 安装依赖 -- 安装PaddlePaddle >= 2.0 +- 安装PaddlePaddle >= 2.3 - 安装PaddleOCR依赖 ``` pip install -r ../requirements.txt ``` - 安装autolog(规范化日志输出工具) ``` - git clone https://github.com/LDOUBLEV/AutoLog - cd AutoLog - pip install -r requirements.txt - python setup.py bdist_wheel - pip install ./dist/auto_log-1.0.0-py3-none-any.whl - cd ../ + pip install https://paddleocr.bj.bcebos.com/libs/auto_log-1.2.0-py3-none-any.whl ``` - 安装PaddleSlim (可选) ``` @@ -49,53 +44,46 @@ Mac端无GPU,环境准备只需要Python环境即可,安装PaddlePaddle等 ### 2.2 功能测试 -先运行`prepare.sh`准备数据和模型,然后运行`test_train_inference_python.sh`进行测试,最终在```test_tipc/output```目录下生成`python_infer_*.log`格式的日志文件。 +先运行`prepare.sh`准备数据和模型,然后运行`test_train_inference_python.sh`进行测试,最终在```test_tipc/output```目录下生成`,model_name/lite_train_lite_infer/*.log`格式的日志文件。 -`test_train_inference_python.sh`包含5种运行模式,每种模式的运行数据不同,分别用于测试速度和精度,分别是: +`test_train_inference_python.sh`包含基础链条的4种运行模式,每种模式的运行数据不同,分别用于测试速度和精度,分别是: - 模式1:lite_train_lite_infer,使用少量数据训练,用于快速验证训练到预测的走通流程,不验证精度和速度; ```shell # 同linux端运行不同的是,Mac端测试使用新的配置文件mac_ppocr_det_mobile_params.txt, # 配置文件中默认去掉了GPU和mkldnn相关的测试链条 -bash test_tipc/prepare.sh ./test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_mac_cpu_normal_normal_infer_python_mac_cpu.txt 'lite_train_lite_infer' -bash test_tipc/test_train_inference_python.sh ./test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_mac_cpu_normal_normal_infer_python_mac_cpu.txt 'lite_train_lite_infer' +bash test_tipc/prepare.sh ./test_tipc/configs/ch_ppocr_mobile_v2_0_det/train_mac_cpu_normal_normal_infer_python_mac_cpu.txt 'lite_train_lite_infer' +bash test_tipc/test_train_inference_python.sh ./test_tipc/configs/ch_ppocr_mobile_v2_0_det/train_mac_cpu_normal_normal_infer_python_mac_cpu.txt 'lite_train_lite_infer' ``` - 模式2:lite_train_whole_infer,使用少量数据训练,一定量数据预测,用于验证训练后的模型执行预测,预测速度是否合理; ```shell -bash test_tipc/prepare.sh ./test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_mac_cpu_normal_normal_infer_python_mac_cpu.txt 'lite_train_whole_infer' -bash test_tipc/test_train_inference_python.sh ./test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_mac_cpu_normal_normal_infer_python_mac_cpu.txt 'lite_train_whole_infer' +bash test_tipc/prepare.sh ./test_tipc/configs/ch_ppocr_mobile_v2_0_det/train_mac_cpu_normal_normal_infer_python_mac_cpu.txt 'lite_train_whole_infer' +bash test_tipc/test_train_inference_python.sh ./test_tipc/configs/ch_ppocr_mobile_v2_0_det/train_mac_cpu_normal_normal_infer_python_mac_cpu.txt 'lite_train_whole_infer' ``` - 模式3:whole_infer,不训练,全量数据预测,走通开源模型评估、动转静,检查inference model预测时间和精度; ```shell -bash test_tipc/prepare.sh ./test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_mac_cpu_normal_normal_infer_python_mac_cpu.txt 'whole_infer' +bash test_tipc/prepare.sh ./test_tipc/configs/ch_ppocr_mobile_v2_0_det/train_mac_cpu_normal_normal_infer_python_mac_cpu.txt 'whole_infer' # 用法1: -bash test_tipc/test_train_inference_python.sh ./test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_mac_cpu_normal_normal_infer_python_mac_cpu.txt 'whole_infer' +bash test_tipc/test_train_inference_python.sh ./test_tipc/configs/ch_ppocr_mobile_v2_0_det/train_mac_cpu_normal_normal_infer_python_mac_cpu.txt 'whole_infer' # 用法2: 指定GPU卡预测,第三个传入参数为GPU卡号 -bash test_tipc/test_train_inference_python.sh 
./test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_mac_cpu_normal_normal_infer_python_mac_cpu.txt 'whole_infer' '1' +bash test_tipc/test_train_inference_python.sh ./test_tipc/configs/ch_ppocr_mobile_v2_0_det/train_mac_cpu_normal_normal_infer_python_mac_cpu.txt 'whole_infer' '1' ``` - 模式4:whole_train_whole_infer,CE: 全量数据训练,全量数据预测,验证模型训练精度,预测精度,预测速度;(Mac端不建议运行此模式) ```shell -bash test_tipc/prepare.sh ./test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_mac_cpu_normal_normal_infer_python_mac_cpu.txt 'whole_train_whole_infer' -bash test_tipc/test_train_inference_python.sh ./test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_mac_cpu_normal_normal_infer_python_mac_cpu.txt 'whole_train_whole_infer' +bash test_tipc/prepare.sh ./test_tipc/configs/ch_ppocr_mobile_v2_0_det/train_mac_cpu_normal_normal_infer_python_mac_cpu.txt 'whole_train_whole_infer' +bash test_tipc/test_train_inference_python.sh ./test_tipc/configs/ch_ppocr_mobile_v2_0_det/train_mac_cpu_normal_normal_infer_python_mac_cpu.txt 'whole_train_whole_infer' ``` -- 模式5:klquant_whole_infer,测试离线量化; -```shell -bash test_tipc/prepare.sh ./test_tipc/configs/ch_ppocr_mobile_v2.0_det_KL/model_linux_gpu_normal_normal_infer_python_mac_cpu.txt 'klquant_whole_infer' -bash test_tipc/test_train_inference_python.sh ./test_tipc/configs/ch_ppocr_mobile_v2.0_det_KL/model_linux_gpu_normal_normal_infer_python_mac_cpu.txt 'klquant_whole_infer' -``` - 运行相应指令后,在`test_tipc/output`文件夹下自动会保存运行日志。如`lite_train_lite_infer`模式下,会运行训练+inference的链条,因此,在`test_tipc/output`文件夹有以下文件: ``` -test_tipc/output/ +test_tipc/output/model_name/lite_train_lite_infer/ |- results_python.log # 运行指令状态的日志 |- norm_train_gpus_-1_autocast_null/ # CPU上正常训练的训练日志和模型保存文件夹 -|- pact_train_gpus_-1_autocast_null/ # CPU上量化训练的训练日志和模型保存文件夹 ...... -|- python_infer_cpu_usemkldnn_False_threads_1_batchsize_1.log # CPU上关闭Mkldnn线程数设置为1,测试batch_size=1条件下的预测运行日志 +|- python_infer_cpu_usemkldnn_False_threads_1_precision_fp32_batchsize_1.log # CPU上关闭Mkldnn线程数设置为1,测试batch_size=1条件下的fp32精度预测运行日志 ...... 
``` diff --git a/test_tipc/docs/test_inference_cpp.md b/test_tipc/docs/test_inference_cpp.md index e662f4bacc0b69bd605a79dac0e36c99daac87d5..5d8aeda6c401b48892de1006c2a024447823defa 100644 --- a/test_tipc/docs/test_inference_cpp.md +++ b/test_tipc/docs/test_inference_cpp.md @@ -17,15 +17,15 @@ C++预测功能测试的主程序为`test_inference_cpp.sh`,可以测试基于 运行环境配置请参考[文档](./install.md)的内容配置TIPC的运行环境。 ### 2.1 功能测试 -先运行`prepare.sh`准备数据和模型,然后运行`test_inference_cpp.sh`进行测试,最终在```test_tipc/output```目录下生成`cpp_infer_*.log`后缀的日志文件。 +先运行`prepare.sh`准备数据和模型,然后运行`test_inference_cpp.sh`进行测试,最终在```test_tipc/output/{model_name}/cpp_infer```目录下生成`cpp_infer_*.log`后缀的日志文件。 ```shell -bash test_tipc/prepare.sh ./test_tipc/configs/ch_ppocr_mobile_v2.0_det/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt "cpp_infer" +bash test_tipc/prepare.sh ./test_tipc/configs/ch_PP-OCRv2_rec/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt "cpp_infer" # 用法1: -bash test_tipc/test_inference_cpp.sh test_tipc/configs/ch_ppocr_mobile_v2.0_det/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt +bash test_tipc/test_inference_cpp.sh test_tipc/configs/ch_PP-OCRv2_rec/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt # 用法2: 指定GPU卡预测,第三个传入参数为GPU卡号 -bash test_tipc/test_inference_cpp.sh test_tipc/configs/ch_ppocr_mobile_v2.0_det/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt '1' +bash test_tipc/test_inference_cpp.sh test_tipc/configs/ch_PP-OCRv2_rec/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt '1' ``` 运行预测指令后,在`test_tipc/output`文件夹下自动会保存运行日志,包括以下文件: @@ -33,23 +33,21 @@ bash test_tipc/test_inference_cpp.sh test_tipc/configs/ch_ppocr_mobile_v2.0_det/ ```shell test_tipc/output/ |- results_cpp.log # 运行指令状态的日志 -|- cpp_infer_cpu_usemkldnn_False_threads_1_precision_fp32_batchsize_1.log # CPU上不开启Mkldnn,线程数设置为1,测试batch_size=1条件下的预测运行日志 -|- cpp_infer_cpu_usemkldnn_False_threads_6_precision_fp32_batchsize_1.log # CPU上不开启Mkldnn,线程数设置为6,测试batch_size=1条件下的预测运行日志 -|- cpp_infer_gpu_usetrt_False_precision_fp32_batchsize_1.log # GPU上不开启TensorRT,测试batch_size=1的fp32精度预测日志 -|- cpp_infer_gpu_usetrt_True_precision_fp16_batchsize_1.log # GPU上开启TensorRT,测试batch_size=1的fp16精度预测日志 +|- cpp_infer_cpu_usemkldnn_False_threads_6_precision_fp32_batchsize_6.log # CPU上不开启Mkldnn,线程数设置为6,测试batch_size=6条件下的预测运行日志 +|- cpp_infer_gpu_usetrt_False_precision_fp32_batchsize_6.log # GPU上不开启TensorRT,测试batch_size=6的fp32精度预测日志 ...... ``` 其中results_cpp.log中包含了每条指令的运行状态,如果运行成功会输出: ``` -Run successfully with command - ./deploy/cpp_infer/build/ppocr det --use_gpu=False --enable_mkldnn=False --cpu_threads=6 --det_model_dir=./inference/ch_ppocr_mobile_v2.0_det_infer/ --rec_batch_num=1 --image_dir=./inference/ch_det_data_50/all-sum-510/ --benchmar k=True > ./test_tipc/output/cpp_infer_cpu_usemkldnn_False_threads_6_precision_fp32_batchsize_1.log 2>&1 ! -Run successfully with command - ./deploy/cpp_infer/build/ppocr det --use_gpu=True --use_tensorrt=False --precision=fp32 --det_model_dir=./inference/ch_ppocr_mobile_v2.0_det_infer/ --rec_batch_num=1 --image_dir=./inference/ch_det_data_50/all-sum-510/ --benchmark =True > ./test_tipc/output/cpp_infer_gpu_usetrt_False_precision_fp32_batchsize_1.log 2>&1 ! 
+[33m Run successfully with command - ch_PP-OCRv2_rec - ./deploy/cpp_infer/build/ppocr --rec_char_dict_path=./ppocr/utils/ppocr_keys_v1.txt --rec_img_h=32 --use_gpu=True --use_tensorrt=False --precision=fp32 --rec_model_dir=./inference/ch_PP-OCRv2_rec_infer/ --rec_batch_num=6 --image_dir=./inference/rec_inference/ --benchmark=True --det=False --rec=True --cls=False --use_angle_cls=False > ./test_tipc/output/ch_PP-OCRv2_rec/cpp_infer/cpp_infer_gpu_usetrt_False_precision_fp32_batchsize_6.log 2>&1 !  + Run successfully with command - ch_PP-OCRv2_rec - ./deploy/cpp_infer/build/ppocr --rec_char_dict_path=./ppocr/utils/ppocr_keys_v1.txt --rec_img_h=32 --use_gpu=False --enable_mkldnn=False --cpu_threads=6 --rec_model_dir=./inference/ch_PP-OCRv2_rec_infer/ --rec_batch_num=6 --image_dir=./inference/rec_inference/ --benchmark=True --det=False --rec=True --cls=False --use_angle_cls=False > ./test_tipc/output/ch_PP-OCRv2_rec/cpp_infer/cpp_infer_cpu_usemkldnn_False_threads_6_precision_fp32_batchsize_6.log 2>&1 !  ...... ``` 如果运行失败,会输出: ``` -Run failed with command - ./deploy/cpp_infer/build/ppocr det --use_gpu=True --use_tensorrt=True --precision=fp32 --det_model_dir=./inference/ch_ppocr_mobile_v2.0_det_infer/ --rec_batch_num=1 --image_dir=./inference/ch_det_data_50/all-sum-510/ --benchmark=True > ./test_tipc/output/cpp_infer_gpu_usetrt_True_precision_fp32_batchsize_1.log 2>&1 ! -Run failed with command - ./deploy/cpp_infer/build/ppocr det --use_gpu=True --use_tensorrt=True --precision=fp16 --det_model_dir=./inference/ch_ppocr_mobile_v2.0_det_infer/ --rec_batch_num=1 --image_dir=./inference/ch_det_data_50/all-sum-510/ --benchmark=True > ./test_tipc/output/cpp_infer_gpu_usetrt_True_precision_fp16_batchsize_1.log 2>&1 ! +Run failed with command - ch_PP-OCRv2_rec - ./deploy/cpp_infer/build/ppocr --rec_char_dict_path=./ppocr/utils/ppocr_keys_v1.txt --rec_img_h=32 --use_gpu=True --use_tensorrt=False --precision=fp32 --rec_model_dir=./inference/ch_PP-OCRv2_rec_infer/ --rec_batch_num=6 --image_dir=./inference/rec_inference/ --benchmark=True --det=False --rec=True --cls=False --use_angle_cls=False > ./test_tipc/output/ch_PP-OCRv2_rec/cpp_infer/cpp_infer_gpu_usetrt_False_precision_fp32_batchsize_6.log 2>&1 ! +Run failed with command - ch_PP-OCRv2_rec - ./deploy/cpp_infer/build/ppocr --rec_char_dict_path=./ppocr/utils/ppocr_keys_v1.txt --rec_img_h=32 --use_gpu=False --enable_mkldnn=False --cpu_threads=6 --rec_model_dir=./inference/ch_PP-OCRv2_rec_infer/ --rec_batch_num=6 --image_dir=./inference/rec_inference/ --benchmark=True --det=False --rec=True --cls=False --use_angle_cls=False > ./test_tipc/output/ch_PP-OCRv2_rec/cpp_infer/cpp_infer_cpu_usemkldnn_False_threads_6_precision_fp32_batchsize_6.log 2>&1 ! ...... ``` 可以很方便的根据results_cpp.log中的内容判定哪一个指令运行错误。 diff --git a/test_tipc/docs/test_paddle2onnx.md b/test_tipc/docs/test_paddle2onnx.md index df2734771e9252a40811c42ead03abbff1b7a1a3..299621d01122995434646351edfd524a0aa3206a 100644 --- a/test_tipc/docs/test_paddle2onnx.md +++ b/test_tipc/docs/test_paddle2onnx.md @@ -15,29 +15,30 @@ PaddleServing预测功能测试的主程序为`test_paddle2onnx.sh`,可以测 ## 2. 
测试流程 ### 2.1 功能测试 -先运行`prepare.sh`准备数据和模型,然后运行`test_paddle2onnx.sh`进行测试,最终在```test_tipc/output```目录下生成`paddle2onnx_infer_*.log`后缀的日志文件。 +先运行`prepare.sh`准备数据和模型,然后运行`test_paddle2onnx.sh`进行测试,最终在```test_tipc/output/{model_name}/paddle2onnx```目录下生成`paddle2onnx_infer_*.log`后缀的日志文件。 ```shell -bash test_tipc/prepare.sh ./test_tipc/configs/ppocr_det_mobile/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt "paddle2onnx_infer" +bash test_tipc/prepare.sh ./test_tipc/configs/ch_PP-OCRv2_det/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt "paddle2onnx_infer" # 用法: -bash test_tipc/test_paddle2onnx.sh ./test_tipc/configs/ppocr_det_mobile/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt +bash test_tipc/test_paddle2onnx.sh ./test_tipc/configs/ch_PP-OCRv2_det/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt ``` #### 运行结果 -各测试的运行情况会打印在 `test_tipc/output/results_paddle2onnx.log` 中: +各测试的运行情况会打印在 `test_tipc/output/{model_name}/paddle2onnx/results_paddle2onnx.log` 中: 运行成功时会输出: ``` -Run successfully with command - paddle2onnx --model_dir=./inference/ch_ppocr_mobile_v2.0_det_infer/ --model_filename=inference.pdmodel --params_filename=inference.pdiparams --save_file=./inference/det_mobile_onnx/model.onnx --opset_version=10 --enable_onnx_checker=True! -Run successfully with command - python test_tipc/onnx_inference/predict_det.py --use_gpu=False --image_dir=./inference/ch_det_data_50/all-sum-510/ --det_model_dir=./inference/det_mobile_onnx/model.onnx 2>&1 ! +Run successfully with command - ch_PP-OCRv2_det - paddle2onnx --model_dir=./inference/ch_PP-OCRv2_det_infer/ --model_filename=inference.pdmodel --params_filename=inference.pdiparams --save_file=./inference/det_v2_onnx/model.onnx --opset_version=10 --enable_onnx_checker=True! +Run successfully with command - ch_PP-OCRv2_det - python3.7 tools/infer/predict_det.py --use_gpu=True --image_dir=./inference/ch_det_data_50/all-sum-510/ --det_model_dir=./inference/det_v2_onnx/model.onnx --use_onnx=True > ./test_tipc/output/ch_PP-OCRv2_det/paddle2onnx/paddle2onnx_infer_gpu.log 2>&1 ! +Run successfully with command - ch_PP-OCRv2_det - python3.7 tools/infer/predict_det.py --use_gpu=False --image_dir=./inference/ch_det_data_50/all-sum-510/ --det_model_dir=./inference/det_v2_onnx/model.onnx --use_onnx=True > ./test_tipc/output/ch_PP-OCRv2_det/paddle2onnx/paddle2onnx_infer_cpu.log 2>&1 ! ``` 运行失败时会输出: ``` -Run failed with command - paddle2onnx --model_dir=./inference/ch_ppocr_mobile_v2.0_det_infer/ --model_filename=inference.pdmodel --params_filename=inference.pdiparams --save_file=./inference/det_mobile_onnx/model.onnx --opset_version=10 --enable_onnx_checker=True! +Run failed with command - ch_PP-OCRv2_det - paddle2onnx --model_dir=./inference/ch_PP-OCRv2_det_infer/ --model_filename=inference.pdmodel --params_filename=inference.pdiparams --save_file=./inference/det_v2_onnx/model.onnx --opset_version=10 --enable_onnx_checker=True! ... ``` diff --git a/test_tipc/docs/test_ptq_inference_python.md b/test_tipc/docs/test_ptq_inference_python.md new file mode 100644 index 0000000000000000000000000000000000000000..7887c0b5c93decac61f56d8c8b92018f40c78b32 --- /dev/null +++ b/test_tipc/docs/test_ptq_inference_python.md @@ -0,0 +1,51 @@ +# Linux GPU/CPU KL离线量化训练推理测试 + +Linux GPU/CPU KL离线量化训练推理测试的主程序为`test_ptq_inference_python.sh`,可以测试基于Python的模型训练、评估、推理等基本功能。 + +## 1. 
测试结论汇总 +- 训练相关: + +| 算法名称 | 模型名称 | 单机单卡 | +| :----: | :----: | :----: | +| | model_name | KL离线量化训练 | + +- 推理相关: + +| 算法名称 | 模型名称 | device_CPU | device_GPU | batchsize | +| :----: | :----: | :----: | :----: | :----: | +| | model_name | 支持 | 支持 | 1 | + +## 2. 测试流程 + +### 2.1 准备数据和模型 + +先运行`prepare.sh`准备数据和模型,然后运行`test_ptq_inference_python.sh`进行测试,最终在```test_tipc/output/{model_name}/whole_infer```目录下生成`python_infer_*.log`后缀的日志文件。 + +```shell +bash test_tipc/prepare.sh ./test_tipc/configs/ch_PP-OCRv2_det/train_ptq_infer_python.txt "whole_infer" + +# 用法: +bash test_tipc/test_ptq_inference_python.sh ./test_tipc/configs/ch_PP-OCRv2_det/train_ptq_infer_python.txt "whole_infer" +``` + +#### 运行结果 + +各测试的运行情况会打印在 `test_tipc/output/{model_name}/paddle2onnx/results_paddle2onnx.log` 中: +运行成功时会输出: + +``` +Run successfully with command - ch_PP-OCRv2_det_KL - python3.7 deploy/slim/quantization/quant_kl.py -c configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml -o Global.pretrained_model=./inference/ch_PP-OCRv2_det_infer/ Global.save_inference_dir=./inference/ch_PP-OCRv2_det_infer/_klquant > ./test_tipc/output/ch_PP-OCRv2_det_KL/whole_infer/whole_infer_export_0.log 2>&1 ! +Run successfully with command - ch_PP-OCRv2_det_KL - python3.7 tools/infer/predict_det.py --use_gpu=False --enable_mkldnn=False --cpu_threads=6 --det_model_dir=./inference/ch_PP-OCRv2_det_infer/_klquant --rec_batch_num=1 --image_dir=./inference/ch_det_data_50/all-sum-510/ --precision=int8 > ./test_tipc/output/ch_PP-OCRv2_det_KL/whole_infer/python_infer_cpu_usemkldnn_False_threads_6_precision_int8_batchsize_1.log 2>&1 ! +Run successfully with command - ch_PP-OCRv2_det_KL - python3.7 tools/infer/predict_det.py --use_gpu=True --use_tensorrt=False --precision=int8 --det_model_dir=./inference/ch_PP-OCRv2_det_infer/_klquant --rec_batch_num=1 --image_dir=./inference/ch_det_data_50/all-sum-510/ > ./test_tipc/output/ch_PP-OCRv2_det_KL/whole_infer/python_infer_gpu_usetrt_False_precision_int8_batchsize_1.log 2>&1 ! +``` + +运行失败时会输出: + +``` +Run failed with command - ch_PP-OCRv2_det_KL - python3.7 deploy/slim/quantization/quant_kl.py -c configs/det/ch_PP-OCRv2/ch_PP-OCRv2_det_cml.yml -o Global.pretrained_model=./inference/ch_PP-OCRv2_det_infer/ Global.save_inference_dir=./inference/ch_PP-OCRv2_det_infer/_klquant > ./test_tipc/output/ch_PP-OCRv2_det_KL/whole_infer/whole_infer_export_0.log 2>&1 ! +... +``` + +## 3. 
更多教程 + +本文档为功能测试用,更详细的量化使用教程请参考:[量化](../../deploy/slim/quantization/README.md) diff --git a/test_tipc/docs/test_serving.md b/test_tipc/docs/test_serving.md index 71f01c0d5ff47004d70baa17b404c10714a6fb64..ef38888784b600233fe85afe3c1064caf12173d4 100644 --- a/test_tipc/docs/test_serving.md +++ b/test_tipc/docs/test_serving.md @@ -18,71 +18,44 @@ PaddleServing预测功能测试的主程序为`test_serving_infer_python.sh`和` ### 2.1 功能测试 **python serving** -先运行`prepare.sh`准备数据和模型,然后运行`test_serving_infer_python.sh`进行测试,最终在```test_tipc/output```目录下生成`serving_infer_python*.log`后缀的日志文件。 +先运行`prepare.sh`准备数据和模型,然后运行`test_serving_infer_python.sh`进行测试,最终在```test_tipc/output/{model_name}/serving_infer/python```目录下生成`python_*.log`后缀的日志文件。 ```shell -bash test_tipc/prepare.sh ./test_tipc/configs/ch_ppocr_mobile_v2.0_det/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt "serving_infer" +bash test_tipc/prepare.sh ./test_tipc/configs/ch_PP-OCRv2/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt "serving_infer" # 用法: -bash test_tipc/test_serving_infer_python.sh ./test_tipc/configs/ch_ppocr_mobile_v2.0_det/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt "serving_infer" +bash test_tipc/test_serving_infer_python.sh ./test_tipc/configs/ch_PP-OCRv2/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt "serving_infer" ``` **cpp serving** -先运行`prepare.sh`准备数据和模型,然后运行`test_serving_infer_cpp.sh`进行测试,最终在```test_tipc/output```目录下生成`serving_infer_cpp*.log`后缀的日志文件。 +先运行`prepare.sh`准备数据和模型,然后运行`test_serving_infer_cpp.sh`进行测试,最终在```test_tipc/output/{model_name}/serving_infer/cpp```目录下生成`cpp_*.log`后缀的日志文件。 ```shell -bash test_tipc/prepare.sh ./test_tipc/configs/ch_ppocr_mobile_v2.0/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt "serving_infer" +bash test_tipc/prepare.sh ./test_tipc/configs/ch_PP-OCRv2/model_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt "serving_infer" # 用法: -bash test_tipc/test_serving_infer_cpp.sh ./test_tipc/configs/ch_ppocr_mobile_v2.0/model_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt "serving_infer" +bash test_tipc/test_serving_infer_cpp.sh ./test_tipc/configs/ch_PP-OCRv2/model_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt "serving_infer" ``` #### 运行结果 -各测试的运行情况会打印在 `test_tipc/output/results_serving.log` 中: +各测试的运行情况会打印在 `test_tipc/output/{model_name}/serving_infer/python(cpp)/results_python(cpp)_serving.log` 中: 运行成功时会输出: ``` -Run successfully with command - python3.7 pipeline_http_client.py --image_dir=../../doc/imgs > ../../tests/output/server_infer_cpu_usemkldnn_True_threads_1_batchsize_1.log 2>&1 ! -Run successfully with command - xxxxx +Run successfully with command - ch_PP-OCRv2_rec - nohup python3.7 web_service_rec.py --config=config.yml --opt op.rec.concurrency="1" op.det.local_service_conf.devices= op.det.local_service_conf.use_mkldnn=False op.det.local_service_conf.thread_num=6 op.rec.local_service_conf.model_config=ppocr_rec_v2_serving > ./test_tipc/output/ch_PP-OCRv2_rec/serving_infer/python/python_server_cpu_usemkldnn_False_threads_6.log 2>&1 &! +Run successfully with command - ch_PP-OCRv2_rec - python3.7 pipeline_http_client.py --det=False --image_dir=../../inference/rec_inference > ./test_tipc/output/ch_PP-OCRv2_rec/serving_infer/python/python_client_cpu_pipeline_http_usemkldnn_False_threads_6_batchsize_1.log 2>&1 ! ... ``` 运行失败时会输出: ``` -Run failed with command - python3.7 pipeline_http_client.py --image_dir=../../doc/imgs > ../../tests/output/server_infer_cpu_usemkldnn_True_threads_1_batchsize_1.log 2>&1 ! 
-Run failed with command - python3.7 pipeline_http_client.py --image_dir=../../doc/imgs > ../../tests/output/server_infer_cpu_usemkldnn_True_threads_6_batchsize_1.log 2>&1 ! -Run failed with command - xxxxx +Run failed with command - ch_PP-OCRv2_rec - nohup python3.7 web_service_rec.py --config=config.yml --opt op.rec.concurrency="1" op.det.local_service_conf.devices= op.det.local_service_conf.use_mkldnn=False op.det.local_service_conf.thread_num=6 op.rec.local_service_conf.model_config=ppocr_rec_v2_serving > ./test_tipc/output/ch_PP-OCRv2_rec/serving_infer/python/python_server_cpu_usemkldnn_False_threads_6.log 2>&1 &! +Run failed with command - ch_PP-OCRv2_rec - python3.7 pipeline_http_client.py --det=False --image_dir=../../inference/rec_inference > ./test_tipc/output/ch_PP-OCRv2_rec/serving_infer/python/python_client_cpu_pipeline_http_usemkldnn_False_threads_6_batchsize_1.log 2>&1 ! ... ``` -详细的预测结果会存在 test_tipc/output/ 文件夹下,例如`server_infer_gpu_usetrt_True_precision_fp16_batchsize_1.log`中会返回检测框的坐标: - -``` -{'err_no': 0, 'err_msg': '', 'key': ['dt_boxes'], 'value': ['[[[ 78. 642.]\n [409. 640.]\n [409. 657.]\n -[ 78. 659.]]\n\n [[ 75. 614.]\n [211. 614.]\n [211. 635.]\n [ 75. 635.]]\n\n -[[103. 554.]\n [135. 554.]\n [135. 575.]\n [103. 575.]]\n\n [[ 75. 531.]\n -[347. 531.]\n [347. 549.]\n [ 75. 549.] ]\n\n [[ 76. 503.]\n [309. 498.]\n -[309. 521.]\n [ 76. 526.]]\n\n [[163. 462.]\n [317. 462.]\n [317. 493.]\n -[163. 493.]]\n\n [[324. 431.]\n [414. 431.]\n [414. 452.]\n [324. 452.]]\n\n -[[ 76. 412.]\n [208. 408.]\n [209. 424.]\n [ 76. 428.]]\n\n [[307. 409.]\n -[428. 409.]\n [428. 426.]\n [307 . 426.]]\n\n [[ 74. 385.]\n [217. 382.]\n -[217. 400.]\n [ 74. 403.]]\n\n [[308. 381.]\n [427. 380.]\n [427. 400.]\n -[308. 401.]]\n\n [[ 74. 363.]\n [195. 362.]\n [195. 378.]\n [ 74. 379.]]\n\n -[[303. 359.]\n [423. 357.]\n [423. 375.]\n [303. 377.]]\n\n [[ 70. 336.]\n -[239. 334.]\n [239. 354.]\ n [ 70. 356.]]\n\n [[ 70. 312.]\n [204. 310.]\n -[204. 327.]\n [ 70. 330.]]\n\n [[303. 308.]\n [419. 306.]\n [419. 326.]\n -[303. 328.]]\n\n [[113. 2 72.]\n [246. 270.]\n [247. 299.]\n [113. 301.]]\n\n - [[361. 269.]\n [384. 269.]\n [384. 296.]\n [361. 296.]]\n\n [[ 70. 250.]\n - [243. 246.]\n [243. 265.]\n [ 70. 269.]]\n\n [[ 65. 221.]\n [187. 220.]\n -[187. 240.]\n [ 65. 241.]]\n\n [[337. 216.]\n [382. 216.]\n [382. 240.]\n -[337. 240.]]\n\n [ [ 65. 196.]\n [247. 193.]\n [247. 213.]\n [ 65. 216.]]\n\n -[[296. 197.]\n [423. 191.]\n [424. 209.]\n [296. 215.]]\n\n [[ 65. 167.]\n [244. 167.]\n -[244. 186.]\n [ 65. 186.]]\n\n [[ 67. 139.]\n [290. 139.]\n [290. 159.]\n [ 67. 159.]]\n\n -[[ 68. 113.]\n [410. 113.]\n [410. 128.]\n [ 68. 129.] ]\n\n [[277. 87.]\n [416. 87.]\n -[416. 108.]\n [277. 108.]]\n\n [[ 79. 28.]\n [132. 28.]\n [132. 62.]\n [ 79. 62.]]\n\n -[[163. 17.]\n [410. 14.]\n [410. 50.]\n [163. 53.]]]']} -``` +详细的预测结果会存在 test_tipc/output/{model_name}/serving_infer/python(cpp)/ 文件夹下 ## 3. 
更多教程 diff --git a/test_tipc/docs/test_train_inference_python.md b/test_tipc/docs/test_train_inference_python.md index 99de9400797493f429f8176a9b6b374a76df4872..d1dbd8ee47a4dc7fb4c0bb3d26a920aab1c7ff72 100644 --- a/test_tipc/docs/test_train_inference_python.md +++ b/test_tipc/docs/test_train_inference_python.md @@ -1,6 +1,6 @@ # Linux端基础训练预测功能测试 -Linux端基础训练预测功能测试的主程序为`test_train_inference_python.sh`,可以测试基于Python的模型训练、评估、推理等基本功能,包括裁剪、量化、蒸馏。 +Linux端基础训练预测功能测试的主程序为`test_train_inference_python.sh`,可以测试基于Python的模型训练、评估、推理等基本功能,包括PACT在线量化。 - Mac端基础训练预测功能测试参考[链接](./mac_test_train_inference_python.md) - Windows端基础训练预测功能测试参考[链接](./win_test_train_inference_python.md) @@ -11,13 +11,14 @@ Linux端基础训练预测功能测试的主程序为`test_train_inference_pytho | 算法名称 | 模型名称 | 单机单卡 | 单机多卡 | 多机多卡 | 模型压缩(单机多卡) | | :---- | :---- | :---- | :---- | :---- | :---- | -| DB | ch_ppocr_mobile_v2.0_det| 正常训练
<br> 混合精度 | 正常训练 <br> 混合精度 | 正常训练 <br> 混合精度 | 正常训练:FPGM裁剪、PACT量化 <br> 离线量化(无需训练) |
-| DB | ch_ppocr_server_v2.0_det| 正常训练 <br> 混合精度 | 正常训练 <br> 混合精度 | 正常训练 <br> 混合精度 | 正常训练:FPGM裁剪、PACT量化 <br> 离线量化(无需训练) |
-| CRNN | ch_ppocr_mobile_v2.0_rec| 正常训练 <br> 混合精度 | 正常训练 <br> 混合精度 | 正常训练 <br> 混合精度 | 正常训练:PACT量化 <br> 离线量化(无需训练) |
-| CRNN | ch_ppocr_server_v2.0_rec| 正常训练 <br> 混合精度 | 正常训练 <br> 混合精度 | 正常训练 <br> 混合精度 | 正常训练:PACT量化 <br> 离线量化(无需训练) |
-|PP-OCR| ch_ppocr_mobile_v2.0| 正常训练 <br> 混合精度 | 正常训练 <br> 混合精度 | 正常训练 <br> 混合精度 | - |
-|PP-OCR| ch_ppocr_server_v2.0| 正常训练 <br> 混合精度 | 正常训练 <br> 混合精度 | 正常训练 <br> 混合精度 | - |
+| DB | ch_ppocr_mobile_v2_0_det| 正常训练 <br> 混合精度 | 正常训练 <br> 混合精度 | 正常训练 <br> 混合精度 | 正常训练:FPGM裁剪、PACT量化 |
+| DB | ch_ppocr_server_v2_0_det| 正常训练 <br> 混合精度 | 正常训练 <br> 混合精度 | 正常训练 <br> 混合精度 | 正常训练:FPGM裁剪、PACT量化 |
+| CRNN | ch_ppocr_mobile_v2_0_rec| 正常训练 <br> 混合精度 | 正常训练 <br> 混合精度 | 正常训练 <br> 混合精度 | 正常训练:PACT量化 |
+| CRNN | ch_ppocr_server_v2_0_rec| 正常训练 <br> 混合精度 | 正常训练 <br> 混合精度 | 正常训练 <br> 混合精度 | 正常训练:PACT量化 |
+|PP-OCR| ch_ppocr_mobile_v2_0| 正常训练 <br> 混合精度 | 正常训练 <br> 混合精度 | 正常训练 <br> 混合精度 | - |
+|PP-OCR| ch_ppocr_server_v2_0| 正常训练 <br> 混合精度 | 正常训练 <br> 混合精度 | 正常训练 <br> 混合精度 | - |
|PP-OCRv2| ch_PP-OCRv2 | 正常训练 <br> 混合精度 | 正常训练 <br> 混合精度 | 正常训练 <br> 混合精度 | - |
+|PP-OCRv3| ch_PP-OCRv3 | 正常训练 <br> 混合精度 | 正常训练 <br> 混合精度 | 正常训练 <br>
混合精度 | - | - 预测相关:基于训练是否使用量化,可以将训练产出的模型可以分为`正常模型`和`量化模型`,这两类模型对应的预测功能汇总如下, @@ -35,19 +36,14 @@ Linux端基础训练预测功能测试的主程序为`test_train_inference_pytho 运行环境配置请参考[文档](./install.md)的内容配置TIPC的运行环境。 ### 2.1 安装依赖 -- 安装PaddlePaddle >= 2.0 +- 安装PaddlePaddle >= 2.3 - 安装PaddleOCR依赖 ``` pip3 install -r ../requirements.txt ``` - 安装autolog(规范化日志输出工具) ``` - git clone https://github.com/LDOUBLEV/AutoLog - cd AutoLog - pip3 install -r requirements.txt - python3 setup.py bdist_wheel - pip3 install ./dist/auto_log-1.0.0-py3-none-any.whl - cd ../ + pip3 install https://paddleocr.bj.bcebos.com/libs/auto_log-1.2.0-py3-none-any.whl ``` - 安装PaddleSlim (可选) ``` @@ -57,60 +53,57 @@ Linux端基础训练预测功能测试的主程序为`test_train_inference_pytho ### 2.2 功能测试 -先运行`prepare.sh`准备数据和模型,然后运行`test_train_inference_python.sh`进行测试,最终在```test_tipc/output```目录下生成`python_infer_*.log`格式的日志文件。 +#### 2.2.1 基础训练推理链条 +先运行`prepare.sh`准备数据和模型,然后运行`test_train_inference_python.sh`进行测试,最终在```test_tipc/output```目录下生成`,model_name/lite_train_lite_infer/*.log`格式的日志文件。 -`test_train_inference_python.sh`包含5种运行模式,每种模式的运行数据不同,分别用于测试速度和精度,分别是: +`test_train_inference_python.sh`包含基础链条的4种运行模式,每种模式的运行数据不同,分别用于测试速度和精度,分别是: - 模式1:lite_train_lite_infer,使用少量数据训练,用于快速验证训练到预测的走通流程,不验证精度和速度; ```shell -bash test_tipc/prepare.sh ./test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_infer_python.txt 'lite_train_lite_infer' -bash test_tipc/test_train_inference_python.sh ./test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_infer_python.txt 'lite_train_lite_infer' +bash test_tipc/prepare.sh ./test_tipc/configs/ch_ppocr_mobile_v2_0_det/train_infer_python.txt 'lite_train_lite_infer' +bash test_tipc/test_train_inference_python.sh ./test_tipc/configs/ch_ppocr_mobile_v2_0_det/train_infer_python.txt 'lite_train_lite_infer' ``` - 模式2:lite_train_whole_infer,使用少量数据训练,一定量数据预测,用于验证训练后的模型执行预测,预测速度是否合理; ```shell -bash test_tipc/prepare.sh ./test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_infer_python.txt 'lite_train_whole_infer' -bash test_tipc/test_train_inference_python.sh ../test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_infer_python.txt 'lite_train_whole_infer' +bash test_tipc/prepare.sh ./test_tipc/configs/ch_ppocr_mobile_v2_0_det/train_infer_python.txt 'lite_train_whole_infer' +bash test_tipc/test_train_inference_python.sh ../test_tipc/configs/ch_ppocr_mobile_v2_0_det/train_infer_python.txt 'lite_train_whole_infer' ``` - 模式3:whole_infer,不训练,全量数据预测,走通开源模型评估、动转静,检查inference model预测时间和精度; ```shell -bash test_tipc/prepare.sh ./test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_infer_python.txt 'whole_infer' +bash test_tipc/prepare.sh ./test_tipc/configs/ch_ppocr_mobile_v2_0_det/train_infer_python.txt 'whole_infer' # 用法1: -bash test_tipc/test_train_inference_python.sh ../test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_infer_python.txt 'whole_infer' +bash test_tipc/test_train_inference_python.sh ../test_tipc/configs/ch_ppocr_mobile_v2_0_det/train_infer_python.txt 'whole_infer' # 用法2: 指定GPU卡预测,第三个传入参数为GPU卡号 -bash test_tipc/test_train_inference_python.sh ./test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_infer_python.txt 'whole_infer' '1' +bash test_tipc/test_train_inference_python.sh ./test_tipc/configs/ch_ppocr_mobile_v2_0_det/train_infer_python.txt 'whole_infer' '1' ``` - 模式4:whole_train_whole_infer,CE: 全量数据训练,全量数据预测,验证模型训练精度,预测精度,预测速度; ```shell -bash test_tipc/prepare.sh ./test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_infer_python.txt 'whole_train_whole_infer' -bash test_tipc/test_train_inference_python.sh ./test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_infer_python.txt 
'whole_train_whole_infer' +bash test_tipc/prepare.sh ./test_tipc/configs/ch_ppocr_mobile_v2_0_det/train_infer_python.txt 'whole_train_whole_infer' +bash test_tipc/test_train_inference_python.sh ./test_tipc/configs/ch_ppocr_mobile_v2_0_det/train_infer_python.txt 'whole_train_whole_infer' ``` -- 模式5:klquant_whole_infer,测试离线量化; -```shell -bash test_tipc/prepare.sh ./test_tipc/configs/ch_ppocr_mobile_v2.0_det_KL/model_linux_gpu_normal_normal_infer_python_linux_gpu_cpu.txt 'klquant_whole_infer' -bash test_tipc/test_train_inference_python.sh ./test_tipc/configs/ch_ppocr_mobile_v2.0_det_KL/model_linux_gpu_normal_normal_infer_python_linux_gpu_cpu.txt 'klquant_whole_infer' -``` - 运行相应指令后,在`test_tipc/output`文件夹下自动会保存运行日志。如'lite_train_lite_infer'模式下,会运行训练+inference的链条,因此,在`test_tipc/output`文件夹有以下文件: ``` -test_tipc/output/ +test_tipc/output/model_name/lite_train_lite_infer/ |- results_python.log # 运行指令状态的日志 -|- norm_train_gpus_0_autocast_null/ # GPU 0号卡上正常训练的训练日志和模型保存文件夹 -|- pact_train_gpus_0_autocast_null/ # GPU 0号卡上量化训练的训练日志和模型保存文件夹 +|- norm_train_gpus_0_autocast_null/ # GPU 0号卡上正常单机单卡训练的训练日志和模型保存文件夹 +|- norm_train_gpus_0,1_autocast_null/ # GPU 0,1号卡上正常单机多卡训练的训练日志和模型保存文件夹 ...... -|- python_infer_cpu_usemkldnn_True_threads_1_batchsize_1.log # CPU上开启Mkldnn线程数设置为1,测试batch_size=1条件下的预测运行日志 -|- python_infer_gpu_usetrt_True_precision_fp16_batchsize_1.log # GPU上开启TensorRT,测试batch_size=1的半精度预测日志 +|- python_infer_cpu_usemkldnn_False_threads_6_precision_fp32_batchsize_1.log # CPU上关闭Mkldnn线程数设置为6,测试batch_size=1条件下的fp32精度预测运行日志 +|- python_infer_gpu_usetrt_False_precision_fp32_batchsize_1.log # GPU上关闭TensorRT,测试batch_size=1的fp32精度预测日志 ...... ``` 其中`results_python.log`中包含了每条指令的运行状态,如果运行成功会输出: ``` -Run successfully with command - python3.7 tools/train.py -c tests/configs/det_mv3_db.yml -o Global.pretrained_model=./pretrain_models/MobileNetV3_large_x0_5_pretrained Global.use_gpu=True Global.save_model_dir=./tests/output/norm_train_gpus_0_autocast_null Global.epoch_num=1 Train.loader.batch_size_per_card=2 ! -Run successfully with command - python3.7 tools/export_model.py -c tests/configs/det_mv3_db.yml -o Global.pretrained_model=./tests/output/norm_train_gpus_0_autocast_null/latest Global.save_inference_dir=./tests/output/norm_train_gpus_0_autocast_null! +[33m Run successfully with command - ch_ppocr_mobile_v2_0_det - python3.7 tools/train.py -c configs/det/ch_ppocr_v2_0/ch_det_mv3_db_v2_0.yml -o Global.pretrained_model=./pretrain_models/MobileNetV3_large_x0_5_pretrained Global.use_gpu=True Global.save_model_dir=./test_tipc/output/ch_ppocr_mobile_v2_0_det/lite_train_lite_infer/norm_train_gpus_0_autocast_null Global.epoch_num=100 Train.loader.batch_size_per_card=2 !  + Run successfully with command - ch_ppocr_mobile_v2_0_det - python3.7 tools/export_model.py -c configs/det/ch_ppocr_v2_0/ch_det_mv3_db_v2_0.yml -o Global.checkpoints=./test_tipc/output/ch_ppocr_mobile_v2_0_det/lite_train_lite_infer/norm_train_gpus_0_autocast_null/latest Global.save_inference_dir=./test_tipc/output/ch_ppocr_mobile_v2_0_det/lite_train_lite_infer/norm_train_gpus_0_autocast_null > ./test_tipc/output/ch_ppocr_mobile_v2_0_det/lite_train_lite_infer/norm_train_gpus_0_autocast_null_nodes_1_export.log 2>&1 !  
+ Run successfully with command - ch_ppocr_mobile_v2_0_det - python3.7 tools/infer/predict_det.py --use_gpu=True --use_tensorrt=False --precision=fp32 --det_model_dir=./test_tipc/output/ch_ppocr_mobile_v2_0_det/lite_train_lite_infer/norm_train_gpus_0_autocast_null --rec_batch_num=1 --image_dir=./train_data/icdar2015/text_localization/ch4_test_images/ --benchmark=True > ./test_tipc/output/ch_ppocr_mobile_v2_0_det/lite_train_lite_infer/python_infer_gpu_usetrt_False_precision_fp32_batchsize_1.log 2>&1 !  + Run successfully with command - ch_ppocr_mobile_v2_0_det - python3.7 tools/infer/predict_det.py --use_gpu=False --enable_mkldnn=False --cpu_threads=6 --det_model_dir=./test_tipc/output/ch_ppocr_mobile_v2_0_det/lite_train_lite_infer/norm_train_gpus_0_autocast_null --rec_batch_num=1 --image_dir=./train_data/icdar2015/text_localization/ch4_test_images/ --benchmark=True --precision=fp32 > ./test_tipc/output/ch_ppocr_mobile_v2_0_det/lite_train_lite_infer/python_infer_cpu_usemkldnn_False_threads_6_precision_fp32_batchsize_1.log 2>&1 !  ...... ``` 如果运行失败,会输出: @@ -121,6 +114,22 @@ Run failed with command - python3.7 tools/export_model.py -c tests/configs/det_m ``` 可以很方便的根据`results_python.log`中的内容判定哪一个指令运行错误。 +#### 2.2.2 PACT在线量化链条 +此外,`test_train_inference_python.sh`还包含PACT在线量化模式,命令如下: +以ch_PP-OCRv2_det为例,如需测试其他模型更换配置即可。 + +```shell +bash test_tipc/prepare.sh ./test_tipc/configs/ch_PP-OCRv2_det/train_pact_infer_python.txt 'lite_train_lite_infer' +bash test_tipc/test_train_inference_python.sh ./test_tipc/configs/ch_PP-OCRv2_det/train_pact_infer_python.txt 'lite_train_lite_infer' +``` +#### 2.2.3 混合精度训练链条 +此外,`test_train_inference_python.sh`还包含混合精度训练模式,命令如下: +以ch_PP-OCRv2_det为例,如需测试其他模型更换配置即可。 + +```shell +bash test_tipc/prepare.sh ./test_tipc/configs/ch_PP-OCRv2_det/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt 'lite_train_lite_infer' +bash test_tipc/test_train_inference_python.sh ./test_tipc/configs/ch_PP-OCRv2_det/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt 'lite_train_lite_infer' +``` ### 2.3 精度测试 diff --git a/test_tipc/docs/win_test_train_inference_python.md b/test_tipc/docs/win_test_train_inference_python.md index 6e3ce93bb3123133075b9d65c64850a87de5f828..d631c38873867ef1fa6e9a03582df26b59e309a5 100644 --- a/test_tipc/docs/win_test_train_inference_python.md +++ b/test_tipc/docs/win_test_train_inference_python.md @@ -8,7 +8,7 @@ Windows端基础训练预测功能测试的主程序为`test_train_inference_pyt | 算法名称 | 模型名称 | 单机单卡 | 单机多卡 | 多机多卡 | 模型压缩(单机多卡) | | :---- | :---- | :---- | :---- | :---- | :---- | -| DB | ch_ppocr_mobile_v2.0_det| 正常训练
<br> 混合精度 | - | - | 正常训练:FPGM裁剪、PACT量化 <br> 离线量化(无需训练) |
+| DB | ch_ppocr_mobile_v2_0_det| 正常训练 <br>
混合精度 | - | - | 正常训练:FPGM裁剪、PACT量化 | - 预测相关:基于训练是否使用量化,可以将训练产出的模型可以分为`正常模型`和`量化模型`,这两类模型对应的预测功能汇总如下: @@ -29,19 +29,14 @@ Windows端基础训练预测功能测试的主程序为`test_train_inference_pyt ### 2.1 安装依赖 -- 安装PaddlePaddle >= 2.0 +- 安装PaddlePaddle >= 2.3 - 安装PaddleOCR依赖 ``` pip install -r ../requirements.txt ``` - 安装autolog(规范化日志输出工具) ``` - git clone https://github.com/LDOUBLEV/AutoLog - cd AutoLog - pip install -r requirements.txt - python setup.py bdist_wheel - pip install ./dist/auto_log-1.0.0-py3-none-any.whl - cd ../ + pip install https://paddleocr.bj.bcebos.com/libs/auto_log-1.2.0-py3-none-any.whl ``` - 安装PaddleSlim (可选) ``` @@ -51,54 +46,46 @@ Windows端基础训练预测功能测试的主程序为`test_train_inference_pyt ### 2.2 功能测试 -先运行`prepare.sh`准备数据和模型,然后运行`test_train_inference_python.sh`进行测试,最终在```test_tipc/output```目录下生成`python_infer_*.log`格式的日志文件。 +先运行`prepare.sh`准备数据和模型,然后运行`test_train_inference_python.sh`进行测试,最终在```test_tipc/output```目录下生成`,model_name/lite_train_lite_infer/*.log`格式的日志文件。 -`test_train_inference_python.sh`包含5种运行模式,每种模式的运行数据不同,分别用于测试速度和精度,分别是: +`test_train_inference_python.sh`包含基础链条的4种运行模式,每种模式的运行数据不同,分别用于测试速度和精度,分别是: - 模式1:lite_train_lite_infer,使用少量数据训练,用于快速验证训练到预测的走通流程,不验证精度和速度; ```shell -bash test_tipc/prepare.sh ./test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_windows_gpu_normal_normal_infer_python_windows_cpu_gpu.txt 'lite_train_lite_infer' -bash test_tipc/test_train_inference_python.sh ./test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_windows_gpu_normal_normal_infer_python_windows_cpu_gpu.txt 'lite_train_lite_infer' +bash test_tipc/prepare.sh ./test_tipc/configs/ch_ppocr_mobile_v2_0_det/train_windows_gpu_normal_normal_infer_python_windows_cpu_gpu.txt 'lite_train_lite_infer' +bash test_tipc/test_train_inference_python.sh ./test_tipc/configs/ch_ppocr_mobile_v2_0_det/train_windows_gpu_normal_normal_infer_python_windows_cpu_gpu.txt 'lite_train_lite_infer' ``` - 模式2:lite_train_whole_infer,使用少量数据训练,一定量数据预测,用于验证训练后的模型执行预测,预测速度是否合理; ```shell -bash test_tipc/prepare.sh ./test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_windows_gpu_normal_normal_infer_python_windows_cpu_gpu.txt 'lite_train_whole_infer' -bash test_tipc/test_train_inference_python.sh ./test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_windows_gpu_normal_normal_infer_python_windows_cpu_gpu.txt 'lite_train_whole_infer' +bash test_tipc/prepare.sh ./test_tipc/configs/ch_ppocr_mobile_v2_0_det/train_windows_gpu_normal_normal_infer_python_windows_cpu_gpu.txt 'lite_train_whole_infer' +bash test_tipc/test_train_inference_python.sh ./test_tipc/configs/ch_ppocr_mobile_v2_0_det/train_windows_gpu_normal_normal_infer_python_windows_cpu_gpu.txt 'lite_train_whole_infer' ``` - 模式3:whole_infer,不训练,全量数据预测,走通开源模型评估、动转静,检查inference model预测时间和精度; ```shell -bash test_tipc/prepare.sh ./test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_windows_gpu_normal_normal_infer_python_windows_cpu_gpu.txt 'whole_infer' +bash test_tipc/prepare.sh ./test_tipc/configs/ch_ppocr_mobile_v2_0_det/train_windows_gpu_normal_normal_infer_python_windows_cpu_gpu.txt 'whole_infer' # 用法1: -bash test_tipc/test_train_inference_python.sh ./test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_windows_gpu_normal_normal_infer_python_windows_cpu_gpu.txt 'whole_infer' +bash test_tipc/test_train_inference_python.sh ./test_tipc/configs/ch_ppocr_mobile_v2_0_det/train_windows_gpu_normal_normal_infer_python_windows_cpu_gpu.txt 'whole_infer' # 用法2: 指定GPU卡预测,第三个传入参数为GPU卡号 -bash test_tipc/test_train_inference_python.sh 
./test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_windows_gpu_normal_normal_infer_python_windows_cpu_gpu.txt 'whole_infer' '1' +bash test_tipc/test_train_inference_python.sh ./test_tipc/configs/ch_ppocr_mobile_v2_0_det/train_windows_gpu_normal_normal_infer_python_windows_cpu_gpu.txt 'whole_infer' '1' ``` - 模式4:whole_train_whole_infer,CE: 全量数据训练,全量数据预测,验证模型训练精度,预测精度,预测速度; ```shell -bash test_tipc/prepare.sh ./test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_windows_gpu_normal_normal_infer_python_windows_cpu_gpu.txt 'whole_train_whole_infer' -bash test_tipc/test_train_inference_python.sh ./test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_windows_gpu_normal_normal_infer_python_windows_cpu_gpu.txt 'whole_train_whole_infer' +bash test_tipc/prepare.sh ./test_tipc/configs/ch_ppocr_mobile_v2_0_det/train_windows_gpu_normal_normal_infer_python_windows_cpu_gpu.txt 'whole_train_whole_infer' +bash test_tipc/test_train_inference_python.sh ./test_tipc/configs/ch_ppocr_mobile_v2_0_det/train_windows_gpu_normal_normal_infer_python_windows_cpu_gpu.txt 'whole_train_whole_infer' ``` -- 模式5:klquant_whole_infer,测试离线量化; -```shell -bash test_tipc/prepare.sh ./test_tipc/configs/ch_ppocr_mobile_v2.0_det_KL/model_linux_gpu_normal_normal_infer_python_windows_gpu_cpu.txt 'klquant_whole_infer' -bash test_tipc/test_train_inference_python.sh ./test_tipc/configs/ch_ppocr_mobile_v2.0_det_KL/model_linux_gpu_normal_normal_infer_python_windows_gpu_cpu.txt 'klquant_whole_infer' -``` - - 运行相应指令后,在`test_tipc/output`文件夹下自动会保存运行日志。如'lite_train_lite_infer'模式下,会运行训练+inference的链条,因此,在`test_tipc/output`文件夹有以下文件: ``` -test_tipc/output/ +test_tipc/output/model_name/lite_train_lite_infer/ |- results_python.log # 运行指令状态的日志 |- norm_train_gpus_0_autocast_null/ # GPU 0号卡上正常训练的训练日志和模型保存文件夹 -|- pact_train_gpus_0_autocast_null/ # GPU 0号卡上量化训练的训练日志和模型保存文件夹 ...... -|- python_infer_cpu_usemkldnn_True_threads_1_batchsize_1.log # CPU上开启Mkldnn线程数设置为1,测试batch_size=1条件下的预测运行日志 -|- python_infer_gpu_usetrt_True_precision_fp16_batchsize_1.log # GPU上开启TensorRT,测试batch_size=1的半精度预测日志 +|- python_infer_cpu_usemkldnn_False_threads_6_precision_fp32_batchsize_1.log # CPU上关闭Mkldnn线程数设置为6,测试batch_size=1条件下的fp32精度预测运行日志 +|- python_infer_gpu_usetrt_False_precision_fp32_batchsize_1.log # GPU上关闭TensorRT,测试batch_size=1的fp32精度预测日志 ...... 
``` diff --git a/test_tipc/prepare.sh b/test_tipc/prepare.sh index bb4b58b4cac900166eeda4d9479fa6bd3fe69e02..688deac0f379b50865fe6739529f9301ebcd919b 100644 --- a/test_tipc/prepare.sh +++ b/test_tipc/prepare.sh @@ -21,7 +21,10 @@ model_name=$(func_parser_value "${lines[1]}") trainer_list=$(func_parser_value "${lines[14]}") if [ ${MODE} = "benchmark_train" ];then - pip install -r requirements.txt + python_name_list=$(func_parser_value "${lines[2]}") + array=(${python_name_list}) + python_name=${array[0]} + ${python_name} -m pip install -r requirements.txt if [[ ${model_name} =~ "ch_ppocr_mobile_v2_0_det" || ${model_name} =~ "det_mv3_db_v2_0" ]];then wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/MobileNetV3_large_x0_5_pretrained.pdparams --no-check-certificate rm -rf ./train_data/icdar2015 @@ -29,6 +32,13 @@ if [ ${MODE} = "benchmark_train" ];then cd ./train_data/ && tar xf icdar2015_benckmark.tar ln -s ./icdar2015_benckmark ./icdar2015 cd ../ + if [[ ${model_name} =~ "ch_ppocr_mobile_v2_0_det" ]];then + # expand gt.txt 2 times + cd ./train_data/icdar2015/text_localization + for i in `seq 2`;do cp train_icdar2015_label.txt dup$i.txt;done + cat dup* > train_icdar2015_label.txt && rm -rf dup* + cd ../../../ + fi fi if [[ ${model_name} =~ "ch_ppocr_server_v2_0_det" || ${model_name} =~ "ch_PP-OCRv3_det" ]];then rm -rf ./train_data/icdar2015 @@ -97,6 +107,15 @@ if [ ${MODE} = "benchmark_train" ];then ln -s ./pubtabnet_benckmark ./pubtabnet cd ../ fi + if [[ ${model_name} == "slanet" ]];then + wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/en_ppstructure_mobile_v2.0_SLANet_train.tar --no-check-certificate + cd ./pretrain_models/ && tar xf en_ppstructure_mobile_v2.0_SLANet_train.tar && cd ../ + rm -rf ./train_data/pubtabnet + wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dataset/pubtabnet_benckmark.tar --no-check-certificate + cd ./train_data/ && tar xf pubtabnet_benckmark.tar + ln -s ./pubtabnet_benckmark ./pubtabnet + cd ../ + fi if [[ ${model_name} == "det_r50_dcn_fce_ctw_v2_0" ]]; then wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/contribution/det_r50_dcn_fce_ctw_v2.0_train.tar --no-check-certificate cd ./pretrain_models/ && tar xf det_r50_dcn_fce_ctw_v2.0_train.tar && cd ../ @@ -107,8 +126,8 @@ if [ ${MODE} = "benchmark_train" ];then cd ../ fi if [ ${model_name} == "layoutxlm_ser" ] || [ ${model_name} == "vi_layoutxlm_ser" ]; then - pip install -r ppstructure/kie/requirements.txt - pip install opencv-python -U + ${python_name} -m pip install -r ppstructure/kie/requirements.txt + ${python_name} -m pip install opencv-python -U wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/ppstructure/dataset/XFUND.tar --no-check-certificate cd ./train_data/ && tar xf XFUND.tar # expand gt.txt 10 times @@ -122,6 +141,11 @@ if [ ${MODE} = "benchmark_train" ];then fi if [ ${MODE} = "lite_train_lite_infer" ];then + python_name_list=$(func_parser_value "${lines[2]}") + array=(${python_name_list}) + python_name=${array[0]} + ${python_name} -m pip install -r requirements.txt + ${python_name} -m pip install https://paddleocr.bj.bcebos.com/libs/auto_log-1.2.0-py3-none-any.whl # pretrain lite train data wget -nc -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV3_large_x0_5_pretrained.pdparams --no-check-certificate wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar --no-check-certificate @@ -212,6 +236,10 @@ if [ 
${MODE} = "lite_train_lite_infer" ];then if [ ${model_name} == "ch_ppocr_mobile_v2_0_rec_FPGM" ]; then wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_train.tar --no-check-certificate cd ./pretrain_models/ && tar xf ch_ppocr_mobile_v2.0_rec_train.tar && cd ../ + ${python_name} -m pip install paddleslim + fi + if [ ${model_name} == "ch_ppocr_mobile_v2_0_det_FPGM" ]; then + ${python_name} -m pip install paddleslim fi if [ ${model_name} == "det_mv3_east_v2_0" ]; then wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_east_v2.0_train.tar --no-check-certificate @@ -229,13 +257,28 @@ if [ ${MODE} = "lite_train_lite_infer" ];then wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/rec_r32_gaspin_bilstm_att_train.tar --no-check-certificate cd ./pretrain_models/ && tar xf rec_r32_gaspin_bilstm_att_train.tar && cd ../ fi - if [ ${model_name} == "layoutxlm_ser" ] || [ ${model_name} == "vi_layoutxlm_ser" ]; then - pip install -r ppstructure/kie/requirements.txt - pip install opencv-python -U + if [ ${model_name} == "layoutxlm_ser" ]; then + ${python_name} -m pip install -r ppstructure/kie/requirements.txt + ${python_name} -m pip install opencv-python -U + wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/ppstructure/dataset/XFUND.tar --no-check-certificate + cd ./train_data/ && tar xf XFUND.tar + cd ../ + + wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar --no-check-certificate + cd ./pretrain_models/ && tar xf ser_LayoutXLM_xfun_zh.tar && cd ../ + fi + if [ ${model_name} == "vi_layoutxlm_ser" ]; then + ${python_name} -m pip install -r ppstructure/kie/requirements.txt + ${python_name} -m pip install opencv-python -U wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/ppstructure/dataset/XFUND.tar --no-check-certificate cd ./train_data/ && tar xf XFUND.tar cd ../ fi + if [ ${model_name} == "det_r18_ct" ]; then + wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/ResNet18_vd_pretrained.pdparams --no-check-certificate + wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dataset/ct_tipc/total_text_lite2.tar --no-check-certificate + cd ./train_data && tar xf total_text_lite2.tar && ln -s total_text_lite2 total_text && cd ../ + fi elif [ ${MODE} = "whole_train_whole_infer" ];then wget -nc -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV3_large_x0_5_pretrained.pdparams --no-check-certificate @@ -304,9 +347,18 @@ elif [ ${MODE} = "lite_train_whole_infer" ];then cd ./inference/ && tar xf en_ppocr_mobile_v2.0_table_det_infer.tar && tar xf en_ppocr_mobile_v2.0_table_rec_infer.tar && cd ../ fi elif [ ${MODE} = "whole_infer" ];then + python_name_list=$(func_parser_value "${lines[2]}") + array=(${python_name_list}) + python_name=${array[0]} + ${python_name} -m pip install paddleslim --force-reinstall + ${python_name} -m pip install -r requirements.txt wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar --no-check-certificate wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/rec_inference.tar --no-check-certificate cd ./inference && tar xf rec_inference.tar && tar xf ch_det_data_50.tar && cd ../ + wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/ppstructure/dataset/XFUND.tar --no-check-certificate + cd ./train_data/ && tar xf XFUND.tar && cd ../ + head -n 2 train_data/XFUND/zh_val/val.json > 
train_data/XFUND/zh_val/val_lite.json + mv train_data/XFUND/zh_val/val_lite.json train_data/XFUND/zh_val/val.json if [ ${model_name} = "ch_ppocr_mobile_v2_0_det" ]; then eval_model_name="ch_ppocr_mobile_v2.0_det_train" rm -rf ./train_data/icdar2015 @@ -467,10 +519,24 @@ elif [ ${MODE} = "whole_infer" ];then cd ./inference/ && tar xf det_r50_dcn_fce_ctw_v2.0_train.tar & cd ../ fi if [[ ${model_name} =~ "en_table_structure" ]];then - wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar --no-check-certificate wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar --no-check-certificate wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar --no-check-certificate - cd ./inference/ && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar && tar xf en_ppocr_mobile_v2.0_table_det_infer.tar && tar xf en_ppocr_mobile_v2.0_table_rec_infer.tar && cd ../ + + cd ./inference/ && tar xf en_ppocr_mobile_v2.0_table_det_infer.tar && tar xf en_ppocr_mobile_v2.0_table_rec_infer.tar + if [ ${model_name} == "en_table_structure" ]; then + wget -nc https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar --no-check-certificate + tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar + elif [ ${model_name} == "en_table_structure_PACT" ]; then + wget -nc https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_slim_infer.tar --no-check-certificate + tar xf en_ppocr_mobile_v2.0_table_structure_slim_infer.tar + fi + cd ../ + fi + if [[ ${model_name} =~ "layoutxlm_ser" ]]; then + ${python_name} -m pip install -r ppstructure/kie/requirements.txt + ${python_name} -m pip install opencv-python -U + wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh_infer.tar --no-check-certificate + cd ./inference/ && tar xf ser_LayoutXLM_xfun_zh_infer.tar & cd ../ fi fi @@ -524,6 +590,12 @@ if [[ ${model_name} =~ "KL" ]]; then cd ./inference/ && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar && tar xf en_ppocr_mobile_v2.0_table_det_infer.tar && tar xf en_ppocr_mobile_v2.0_table_rec_infer.tar && cd ../ cd ./train_data/ && tar xf pubtabnet.tar && cd ../ fi + if [[ ${model_name} =~ "layoutxlm_ser_KL" ]]; then + wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/ppstructure/dataset/XFUND.tar --no-check-certificate + cd ./train_data/ && tar xf XFUND.tar && cd ../ + wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh_infer.tar --no-check-certificate + cd ./inference/ && tar xf ser_LayoutXLM_xfun_zh_infer.tar & cd ../ + fi fi if [ ${MODE} = "cpp_infer" ];then @@ -627,7 +699,25 @@ if [ ${MODE} = "cpp_infer" ];then wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar --no-check-certificate wget -nc -P ./inference https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_infer.tar --no-check-certificate cd ./inference && tar xf ch_PP-OCRv3_det_infer.tar && tar xf ch_PP-OCRv3_rec_infer.tar && tar xf ch_det_data_50.tar && cd ../ - fi + elif [[ ${model_name} =~ "en_table_structure" ]];then + wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar --no-check-certificate + wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar 
--no-check-certificate + + cd ./inference/ && tar xf en_ppocr_mobile_v2.0_table_det_infer.tar && tar xf en_ppocr_mobile_v2.0_table_rec_infer.tar + if [ ${model_name} == "en_table_structure" ]; then + wget -nc https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar --no-check-certificate + tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar + elif [ ${model_name} == "en_table_structure_PACT" ]; then + wget -nc https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_slim_infer.tar --no-check-certificate + tar xf en_ppocr_mobile_v2.0_table_structure_slim_infer.tar + fi + cd ../ + elif [[ ${model_name} =~ "slanet" ]];then + wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/ch_ppstructure_mobile_v2.0_SLANet_infer.tar --no-check-certificate + wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_infer.tar --no-check-certificate + wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_infer.tar --no-check-certificate + cd ./inference/ && tar xf ch_ppstructure_mobile_v2.0_SLANet_infer.tar && tar xf ch_PP-OCRv3_det_infer.tar && tar xf ch_PP-OCRv3_rec_infer.tar && cd ../ + fi fi if [ ${MODE} = "serving_infer" ];then @@ -639,6 +729,7 @@ if [ ${MODE} = "serving_infer" ];then ${python_name} -m pip install paddle-serving-server-gpu ${python_name} -m pip install paddle_serving_client ${python_name} -m pip install paddle-serving-app + ${python_name} -m pip install https://paddleocr.bj.bcebos.com/libs/auto_log-1.2.0-py3-none-any.whl # wget model if [ ${model_name} == "ch_ppocr_mobile_v2_0_det_KL" ] || [ ${model_name} == "ch_ppocr_mobile_v2.0_rec_KL" ] ; then wget -nc -P ./inference https://paddleocr.bj.bcebos.com/tipc_fake_model/ch_ppocr_mobile_v2.0_det_klquant_infer.tar --no-check-certificate @@ -690,8 +781,7 @@ fi if [ ${MODE} = "paddle2onnx_infer" ];then # prepare serving env python_name=$(func_parser_value "${lines[2]}") - ${python_name} -m pip install paddle2onnx - ${python_name} -m pip install onnxruntime + ${python_name} -m pip install paddle2onnx onnxruntime onnx # wget model if [[ ${model_name} =~ "ch_ppocr_mobile_v2_0" ]]; then wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar --no-check-certificate @@ -709,6 +799,12 @@ if [ ${MODE} = "paddle2onnx_infer" ];then wget -nc -P ./inference https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_infer.tar --no-check-certificate wget -nc -P ./inference https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_infer.tar --no-check-certificate cd ./inference && tar xf ch_PP-OCRv3_det_infer.tar && tar xf ch_PP-OCRv3_rec_infer.tar && cd ../ + elif [[ ${model_name} =~ "slanet" ]];then + wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/ch_ppstructure_mobile_v2.0_SLANet_infer.tar --no-check-certificate + cd ./inference/ && tar xf ch_ppstructure_mobile_v2.0_SLANet_infer.tar && cd ../ + elif [[ ${model_name} =~ "en_table_structure" ]];then + wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar --no-check-certificate + cd ./inference/ && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar && cd ../ fi # wget data diff --git a/test_tipc/test_paddle2onnx.sh b/test_tipc/test_paddle2onnx.sh index bace6b2d4684e0ad40ffbd76b37a78ddf1e70722..f035e6bb645a1e7927844232c2bff72f0480e38e 100644 --- 
a/test_tipc/test_paddle2onnx.sh +++ b/test_tipc/test_paddle2onnx.sh @@ -63,7 +63,7 @@ function func_paddle2onnx(){ set_opset_version=$(func_set_params "${opset_version_key}" "${opset_version_value}") set_enable_onnx_checker=$(func_set_params "${enable_onnx_checker_key}" "${enable_onnx_checker_value}") trans_det_log="${LOG_PATH}/trans_model_det.log" - trans_model_cmd="${padlle2onnx_cmd} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_save_model} ${set_opset_version} ${set_enable_onnx_checker} > ${trans_det_log} 2>&1 " + trans_model_cmd="${padlle2onnx_cmd} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_save_model} ${set_opset_version} ${set_enable_onnx_checker} --enable_dev_version=False > ${trans_det_log} 2>&1 " eval $trans_model_cmd last_status=${PIPESTATUS[0]} status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}" "${trans_det_log}" @@ -75,7 +75,7 @@ function func_paddle2onnx(){ set_opset_version=$(func_set_params "${opset_version_key}" "${opset_version_value}") set_enable_onnx_checker=$(func_set_params "${enable_onnx_checker_key}" "${enable_onnx_checker_value}") trans_rec_log="${LOG_PATH}/trans_model_rec.log" - trans_model_cmd="${padlle2onnx_cmd} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_save_model} ${set_opset_version} ${set_enable_onnx_checker} > ${trans_rec_log} 2>&1 " + trans_model_cmd="${padlle2onnx_cmd} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_save_model} ${set_opset_version} ${set_enable_onnx_checker} --enable_dev_version=False > ${trans_rec_log} 2>&1 " eval $trans_model_cmd last_status=${PIPESTATUS[0]} status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}" "${trans_rec_log}" @@ -88,7 +88,7 @@ function func_paddle2onnx(){ set_opset_version=$(func_set_params "${opset_version_key}" "${opset_version_value}") set_enable_onnx_checker=$(func_set_params "${enable_onnx_checker_key}" "${enable_onnx_checker_value}") trans_det_log="${LOG_PATH}/trans_model_det.log" - trans_model_cmd="${padlle2onnx_cmd} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_save_model} ${set_opset_version} ${set_enable_onnx_checker} > ${trans_det_log} 2>&1 " + trans_model_cmd="${padlle2onnx_cmd} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_save_model} ${set_opset_version} ${set_enable_onnx_checker} --enable_dev_version=False > ${trans_det_log} 2>&1 " eval $trans_model_cmd last_status=${PIPESTATUS[0]} status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}" "${trans_det_log}" @@ -101,10 +101,23 @@ function func_paddle2onnx(){ set_opset_version=$(func_set_params "${opset_version_key}" "${opset_version_value}") set_enable_onnx_checker=$(func_set_params "${enable_onnx_checker_key}" "${enable_onnx_checker_value}") trans_rec_log="${LOG_PATH}/trans_model_rec.log" - trans_model_cmd="${padlle2onnx_cmd} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_save_model} ${set_opset_version} ${set_enable_onnx_checker} > ${trans_rec_log} 2>&1 " + trans_model_cmd="${padlle2onnx_cmd} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_save_model} ${set_opset_version} ${set_enable_onnx_checker} --enable_dev_version=False > ${trans_rec_log} 2>&1 " eval $trans_model_cmd last_status=${PIPESTATUS[0]} status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}" "${trans_rec_log}" + elif [ ${model_name} = "slanet" ] || [ ${model_name} = "en_table_structure" ]; then + # trans det + 
set_dirname=$(func_set_params "--model_dir" "${det_infer_model_dir_value}") + set_model_filename=$(func_set_params "${model_filename_key}" "${model_filename_value}") + set_params_filename=$(func_set_params "${params_filename_key}" "${params_filename_value}") + set_save_model=$(func_set_params "--save_file" "${det_save_file_value}") + set_opset_version=$(func_set_params "${opset_version_key}" "${opset_version_value}") + set_enable_onnx_checker=$(func_set_params "${enable_onnx_checker_key}" "${enable_onnx_checker_value}") + trans_det_log="${LOG_PATH}/trans_model_det.log" + trans_model_cmd="${padlle2onnx_cmd} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_save_model} ${set_opset_version} ${set_enable_onnx_checker} --enable_dev_version=True > ${trans_det_log} 2>&1 " + eval $trans_model_cmd + last_status=${PIPESTATUS[0]} + status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}" "${trans_det_log}" fi # python inference @@ -117,7 +130,7 @@ function func_paddle2onnx(){ set_det_model_dir=$(func_set_params "${det_model_key}" "${det_save_file_value}") set_rec_model_dir=$(func_set_params "${rec_model_key}" "${rec_save_file_value}") infer_model_cmd="${python} ${inference_py} ${set_gpu} ${set_img_dir} ${set_det_model_dir} ${set_rec_model_dir} --use_onnx=True > ${_save_log_path} 2>&1 " - elif [[ ${model_name} =~ "det" ]]; then + elif [[ ${model_name} =~ "det" ]] || [ ${model_name} = "slanet" ] || [ ${model_name} = "en_table_structure" ]; then set_det_model_dir=$(func_set_params "${det_model_key}" "${det_save_file_value}") infer_model_cmd="${python} ${inference_py} ${set_gpu} ${set_img_dir} ${set_det_model_dir} --use_onnx=True > ${_save_log_path} 2>&1 " elif [[ ${model_name} =~ "rec" ]]; then @@ -136,7 +149,7 @@ function func_paddle2onnx(){ set_det_model_dir=$(func_set_params "${det_model_key}" "${det_save_file_value}") set_rec_model_dir=$(func_set_params "${rec_model_key}" "${rec_save_file_value}") infer_model_cmd="${python} ${inference_py} ${set_gpu} ${set_img_dir} ${set_det_model_dir} ${set_rec_model_dir} --use_onnx=True > ${_save_log_path} 2>&1 " - elif [[ ${model_name} =~ "det" ]]; then + elif [[ ${model_name} =~ "det" ]]|| [ ${model_name} = "slanet" ] || [ ${model_name} = "en_table_structure" ]; then set_det_model_dir=$(func_set_params "${det_model_key}" "${det_save_file_value}") infer_model_cmd="${python} ${inference_py} ${set_gpu} ${set_img_dir} ${set_det_model_dir} --use_onnx=True > ${_save_log_path} 2>&1 " elif [[ ${model_name} =~ "rec" ]]; then diff --git a/test_tipc/test_train_inference_python_npu.sh b/test_tipc/test_train_inference_python_npu.sh new file mode 100644 index 0000000000000000000000000000000000000000..bab70fc78ee902515c0fccb57d9215d86f2a6589 --- /dev/null +++ b/test_tipc/test_train_inference_python_npu.sh @@ -0,0 +1,52 @@ +#!/bin/bash +source test_tipc/common_func.sh + +function readlinkf() { + perl -MCwd -e 'print Cwd::abs_path shift' "$1"; +} + +function func_parser_config() { + strs=$1 + IFS=" " + array=(${strs}) + tmp=${array[2]} + echo ${tmp} +} + +BASEDIR=$(dirname "$0") +REPO_ROOT_PATH=$(readlinkf ${BASEDIR}/../) + +FILENAME=$1 + +# disable mkldnn on non x86_64 env +arch=$(uname -i) +if [ $arch != 'x86_64' ]; then + sed -i 's/--enable_mkldnn:True|False/--enable_mkldnn:False/g' $FILENAME + sed -i 's/--enable_mkldnn:True/--enable_mkldnn:False/g' $FILENAME +fi + +# change gpu to npu in tipc txt configs +sed -i 's/use_gpu/use_npu/g' $FILENAME +# disable benchmark as AutoLog required nvidia-smi command +sed -i 
's/--benchmark:True/--benchmark:False/g' $FILENAME +dataline=`cat $FILENAME` + +# parser params +IFS=$'\n' +lines=(${dataline}) + +# replace training config file +grep -n 'tools/.*yml' $FILENAME | cut -d ":" -f 1 \ +| while read line_num ; do + train_cmd=$(func_parser_value "${lines[line_num-1]}") + trainer_config=$(func_parser_config ${train_cmd}) + sed -i 's/use_gpu/use_npu/g' "$REPO_ROOT_PATH/$trainer_config" +done + +# change gpu to npu in execution script +sed -i 's/\"gpu\"/\"npu\"/g' test_tipc/test_train_inference_python.sh + +# pass parameters to test_train_inference_python.sh +cmd='bash test_tipc/test_train_inference_python.sh ${FILENAME} $2' +echo -e '\033[1;32m Started to run command: ${cmd}! \033[0m' +eval $cmd diff --git a/test_tipc/test_train_inference_python_xpu.sh b/test_tipc/test_train_inference_python_xpu.sh new file mode 100644 index 0000000000000000000000000000000000000000..7c6dc1e52a67caf9c858b2f8b6561b3919134b0b --- /dev/null +++ b/test_tipc/test_train_inference_python_xpu.sh @@ -0,0 +1,52 @@ +#!/bin/bash +source test_tipc/common_func.sh + +function readlinkf() { + perl -MCwd -e 'print Cwd::abs_path shift' "$1"; +} + +function func_parser_config() { + strs=$1 + IFS=" " + array=(${strs}) + tmp=${array[2]} + echo ${tmp} +} + +BASEDIR=$(dirname "$0") +REPO_ROOT_PATH=$(readlinkf ${BASEDIR}/../) + +FILENAME=$1 + +# disable mkldnn on non x86_64 env +arch=$(uname -i) +if [ $arch != 'x86_64' ]; then + sed -i 's/--enable_mkldnn:True|False/--enable_mkldnn:False/g' $FILENAME + sed -i 's/--enable_mkldnn:True/--enable_mkldnn:False/g' $FILENAME +fi + +# change gpu to xpu in tipc txt configs +sed -i 's/use_gpu/use_xpu/g' $FILENAME +# disable benchmark as AutoLog required nvidia-smi command +sed -i 's/--benchmark:True/--benchmark:False/g' $FILENAME +dataline=`cat $FILENAME` + +# parser params +IFS=$'\n' +lines=(${dataline}) + +# replace training config file +grep -n 'tools/.*yml' $FILENAME | cut -d ":" -f 1 \ +| while read line_num ; do + train_cmd=$(func_parser_value "${lines[line_num-1]}") + trainer_config=$(func_parser_config ${train_cmd}) + sed -i 's/use_gpu/use_xpu/g' "$REPO_ROOT_PATH/$trainer_config" +done + +# change gpu to xpu in execution script +sed -i 's/\"gpu\"/\"xpu\"/g' test_tipc/test_train_inference_python.sh + +# pass parameters to test_train_inference_python.sh +cmd='bash test_tipc/test_train_inference_python.sh ${FILENAME} $2' +echo -e '\033[1;32m Started to run command: ${cmd}! 
\033[0m' +eval $cmd diff --git a/tools/export_model.py b/tools/export_model.py index 193988cc1b62a6c4536a8d2ec640e3e5fc81a79c..8610df83ef08926c245872e711cd1c828eb46765 100755 --- a/tools/export_model.py +++ b/tools/export_model.py @@ -115,16 +115,12 @@ def export_single_model(model, max_text_length = arch_config["Head"]["max_text_length"] other_shape = [ paddle.static.InputSpec( - shape=[None, 3, 48, 160], dtype="float32"), - - [ - paddle.static.InputSpec( - shape=[None, ], - dtype="float32"), - paddle.static.InputSpec( - shape=[None, max_text_length], - dtype="int64") - ] + shape=[None, 3, 48, 160], dtype="float32"), [ + paddle.static.InputSpec( + shape=[None, ], dtype="float32"), + paddle.static.InputSpec( + shape=[None, max_text_length], dtype="int64") + ] ] model = to_static(model, input_spec=other_shape) elif arch_config["algorithm"] in ["LayoutLM", "LayoutLMv2", "LayoutXLM"]: @@ -140,6 +136,13 @@ def export_single_model(model, paddle.static.InputSpec( shape=[None, 3, 224, 224], dtype="int64"), # image ] + if 'Re' in arch_config['Backbone']['name']: + input_spec.extend([ + paddle.static.InputSpec( + shape=[None, 512, 3], dtype="int64"), # entities + paddle.static.InputSpec( + shape=[None, None, 2], dtype="int64"), # relations + ]) if model.backbone.use_visual_backbone is False: input_spec.pop(4) model = to_static(model, input_spec=[input_spec]) diff --git a/tools/infer/predict_det.py b/tools/infer/predict_det.py index 9f5c480d3c55367a02eacb48bed6ae3d38282f05..52c225d2b3913cf8c0dc88abcc07f7ccfd3cc914 100755 --- a/tools/infer/predict_det.py +++ b/tools/infer/predict_det.py @@ -127,6 +127,9 @@ class TextDetector(object): postprocess_params["beta"] = args.beta postprocess_params["fourier_degree"] = args.fourier_degree postprocess_params["box_type"] = args.det_fce_box_type + elif self.det_algorithm == "CT": + pre_process_list[0] = {'ScaleAlignedShort': {'short_size': 640}} + postprocess_params['name'] = 'CTPostProcess' else: logger.info("unknown det_algorithm:{}".format(self.det_algorithm)) sys.exit(0) @@ -253,6 +256,9 @@ class TextDetector(object): elif self.det_algorithm == 'FCE': for i, output in enumerate(outputs): preds['level_{}'.format(i)] = output + elif self.det_algorithm == "CT": + preds['maps'] = outputs[0] + preds['score'] = outputs[1] else: raise NotImplementedError @@ -260,7 +266,7 @@ class TextDetector(object): post_result = self.postprocess_op(preds, shape_list) dt_boxes = post_result[0]['points'] if (self.det_algorithm == "SAST" and self.det_sast_polygon) or ( - self.det_algorithm in ["PSE", "FCE"] and + self.det_algorithm in ["PSE", "FCE", "CT"] and self.postprocess_op.box_type == 'poly'): dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_im.shape) else: @@ -276,44 +282,67 @@ if __name__ == "__main__": args = utility.parse_args() image_file_list = get_image_file_list(args.image_dir) text_detector = TextDetector(args) - count = 0 total_time = 0 - draw_img_save = "./inference_results" + draw_img_save_dir = args.draw_img_save_dir + os.makedirs(draw_img_save_dir, exist_ok=True) if args.warmup: img = np.random.uniform(0, 255, [640, 640, 3]).astype(np.uint8) for i in range(2): res = text_detector(img) - if not os.path.exists(draw_img_save): - os.makedirs(draw_img_save) save_results = [] - for image_file in image_file_list: - img, flag, _ = check_and_read(image_file) - if not flag: + for idx, image_file in enumerate(image_file_list): + img, flag_gif, flag_pdf = check_and_read(image_file) + if not flag_gif and not flag_pdf: img = cv2.imread(image_file) - if img is None: 
- logger.info("error in loading image:{}".format(image_file)) - continue - st = time.time() - dt_boxes, _ = text_detector(img) - elapse = time.time() - st - if count > 0: + if not flag_pdf: + if img is None: + logger.debug("error in loading image:{}".format(image_file)) + continue + imgs = [img] + else: + page_num = args.page_num + if page_num > len(img) or page_num == 0: + page_num = len(img) + imgs = img[:page_num] + for index, img in enumerate(imgs): + st = time.time() + dt_boxes, _ = text_detector(img) + elapse = time.time() - st total_time += elapse - count += 1 - save_pred = os.path.basename(image_file) + "\t" + str( - json.dumps([x.tolist() for x in dt_boxes])) + "\n" - save_results.append(save_pred) - logger.info(save_pred) - logger.info("The predict time of {}: {}".format(image_file, elapse)) - src_im = utility.draw_text_det_res(dt_boxes, image_file) - img_name_pure = os.path.split(image_file)[-1] - img_path = os.path.join(draw_img_save, - "det_res_{}".format(img_name_pure)) - cv2.imwrite(img_path, src_im) - logger.info("The visualized image saved in {}".format(img_path)) + if len(imgs) > 1: + save_pred = os.path.basename(image_file) + '_' + str( + index) + "\t" + str( + json.dumps([x.tolist() for x in dt_boxes])) + "\n" + else: + save_pred = os.path.basename(image_file) + "\t" + str( + json.dumps([x.tolist() for x in dt_boxes])) + "\n" + save_results.append(save_pred) + logger.info(save_pred) + if len(imgs) > 1: + logger.info("{}_{} The predict time of {}: {}".format( + idx, index, image_file, elapse)) + else: + logger.info("{} The predict time of {}: {}".format( + idx, image_file, elapse)) + + src_im = utility.draw_text_det_res(dt_boxes, img) + + if flag_gif: + save_file = image_file[:-3] + "png" + elif flag_pdf: + save_file = image_file.replace('.pdf', + '_' + str(index) + '.png') + else: + save_file = image_file + img_path = os.path.join( + draw_img_save_dir, + "det_res_{}".format(os.path.basename(save_file))) + cv2.imwrite(img_path, src_im) + logger.info("The visualized image saved in {}".format(img_path)) - with open(os.path.join(draw_img_save, "det_results.txt"), 'w') as f: + with open(os.path.join(draw_img_save_dir, "det_results.txt"), 'w') as f: f.writelines(save_results) f.close() if args.benchmark: diff --git a/tools/infer/predict_system.py b/tools/infer/predict_system.py index e0f2c41fa2aba23491efee920afbd76db1ec84e0..affd0d1bcd1283be02ead3cd61c01c375b49bdf9 100755 --- a/tools/infer/predict_system.py +++ b/tools/infer/predict_system.py @@ -159,50 +159,75 @@ def main(args): count = 0 for idx, image_file in enumerate(image_file_list): - img, flag, _ = check_and_read(image_file) - if not flag: + img, flag_gif, flag_pdf = check_and_read(image_file) + if not flag_gif and not flag_pdf: img = cv2.imread(image_file) - if img is None: - logger.debug("error in loading image:{}".format(image_file)) - continue - starttime = time.time() - dt_boxes, rec_res, time_dict = text_sys(img) - elapse = time.time() - starttime - total_time += elapse - - logger.debug( - str(idx) + " Predict time of %s: %.3fs" % (image_file, elapse)) - for text, score in rec_res: - logger.debug("{}, {:.3f}".format(text, score)) - - res = [{ - "transcription": rec_res[idx][0], - "points": np.array(dt_boxes[idx]).astype(np.int32).tolist(), - } for idx in range(len(dt_boxes))] - save_pred = os.path.basename(image_file) + "\t" + json.dumps( - res, ensure_ascii=False) + "\n" - save_results.append(save_pred) - - if is_visualize: - image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) - boxes = dt_boxes - 
txts = [rec_res[i][0] for i in range(len(rec_res))] - scores = [rec_res[i][1] for i in range(len(rec_res))] - - draw_img = draw_ocr_box_txt( - image, - boxes, - txts, - scores, - drop_score=drop_score, - font_path=font_path) - if flag: - image_file = image_file[:-3] + "png" - cv2.imwrite( - os.path.join(draw_img_save_dir, os.path.basename(image_file)), - draw_img[:, :, ::-1]) - logger.debug("The visualized image saved in {}".format( - os.path.join(draw_img_save_dir, os.path.basename(image_file)))) + if not flag_pdf: + if img is None: + logger.debug("error in loading image:{}".format(image_file)) + continue + imgs = [img] + else: + page_num = args.page_num + if page_num > len(img) or page_num == 0: + page_num = len(img) + imgs = img[:page_num] + for index, img in enumerate(imgs): + starttime = time.time() + dt_boxes, rec_res, time_dict = text_sys(img) + elapse = time.time() - starttime + total_time += elapse + if len(imgs) > 1: + logger.debug( + str(idx) + '_' + str(index) + " Predict time of %s: %.3fs" + % (image_file, elapse)) + else: + logger.debug( + str(idx) + " Predict time of %s: %.3fs" % (image_file, + elapse)) + for text, score in rec_res: + logger.debug("{}, {:.3f}".format(text, score)) + + res = [{ + "transcription": rec_res[i][0], + "points": np.array(dt_boxes[i]).astype(np.int32).tolist(), + } for i in range(len(dt_boxes))] + if len(imgs) > 1: + save_pred = os.path.basename(image_file) + '_' + str( + index) + "\t" + json.dumps( + res, ensure_ascii=False) + "\n" + else: + save_pred = os.path.basename(image_file) + "\t" + json.dumps( + res, ensure_ascii=False) + "\n" + save_results.append(save_pred) + + if is_visualize: + image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) + boxes = dt_boxes + txts = [rec_res[i][0] for i in range(len(rec_res))] + scores = [rec_res[i][1] for i in range(len(rec_res))] + + draw_img = draw_ocr_box_txt( + image, + boxes, + txts, + scores, + drop_score=drop_score, + font_path=font_path) + if flag_gif: + save_file = image_file[:-3] + "png" + elif flag_pdf: + save_file = image_file.replace('.pdf', + '_' + str(index) + '.png') + else: + save_file = image_file + cv2.imwrite( + os.path.join(draw_img_save_dir, + os.path.basename(save_file)), + draw_img[:, :, ::-1]) + logger.debug("The visualized image saved in {}".format( + os.path.join(draw_img_save_dir, os.path.basename( + save_file)))) logger.info("The predict total time is {}".format(time.time() - _st)) if args.benchmark: diff --git a/tools/infer/utility.py b/tools/infer/utility.py index 9baf66d7f469a3bf6c9a140e034aee3a635a5c8e..e555dbec1b314510aaaf6b31f1b35bf60fefa98e 100644 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -23,6 +23,7 @@ from PIL import Image, ImageDraw, ImageFont import math from paddle import inference import time +import random from ppocr.utils.logging import get_logger @@ -35,15 +36,16 @@ def init_args(): # params for prediction engine parser.add_argument("--use_gpu", type=str2bool, default=True) parser.add_argument("--use_xpu", type=str2bool, default=False) + parser.add_argument("--use_npu", type=str2bool, default=False) parser.add_argument("--ir_optim", type=str2bool, default=True) parser.add_argument("--use_tensorrt", type=str2bool, default=False) parser.add_argument("--min_subgraph_size", type=int, default=15) - parser.add_argument("--shape_info_filename", type=str, default=None) parser.add_argument("--precision", type=str, default="fp32") parser.add_argument("--gpu_mem", type=int, default=500) # params for text detector parser.add_argument("--image_dir", 
type=str) + parser.add_argument("--page_num", type=int, default=0) parser.add_argument("--det_algorithm", type=str, default='DB') parser.add_argument("--det_model_dir", type=str) parser.add_argument("--det_limit_side_len", type=float, default=960) @@ -161,6 +163,8 @@ def create_predictor(args, mode, logger): model_dir = args.table_model_dir elif mode == 'ser': model_dir = args.ser_model_dir + elif mode == 're': + model_dir = args.re_model_dir elif mode == "sr": model_dir = args.sr_model_dir elif mode == 'layout': @@ -226,24 +230,22 @@ def create_predictor(args, mode, logger): use_calib_mode=False) # collect shape - if args.shape_info_filename is not None: - if not os.path.exists(args.shape_info_filename): - config.collect_shape_range_info( - args.shape_info_filename) - logger.info( - f"collect dynamic shape info into : {args.shape_info_filename}" - ) - else: - logger.info( - f"dynamic shape info file( {args.shape_info_filename} ) already exists, not need to generate again." - ) - config.enable_tuned_tensorrt_dynamic_shape( - args.shape_info_filename, True) - else: - logger.info( - f"when using tensorrt, dynamic shape is a suggested option, you can use '--shape_info_filename=shape.txt' for offline dygnamic shape tuning" - ) + trt_shape_f = os.path.join(model_dir, + f"{mode}_trt_dynamic_shape.txt") + if not os.path.exists(trt_shape_f): + config.collect_shape_range_info(trt_shape_f) + logger.info( + f"collect dynamic shape info into : {trt_shape_f}") + try: + config.enable_tuned_tensorrt_dynamic_shape(trt_shape_f, + True) + except Exception as E: + logger.info(E) + logger.info("Please keep your paddlepaddle-gpu >= 2.3.0!") + + elif args.use_npu: + config.enable_npu() elif args.use_xpu: config.enable_xpu(10 * 1024 * 1024) else: @@ -264,6 +266,8 @@ def create_predictor(args, mode, logger): config.disable_glog_info() config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass") config.delete_pass("matmul_transpose_reshape_fuse_pass") + if mode == 're': + config.delete_pass("simplify_with_basic_ops_pass") if mode == 'table': config.delete_pass("fc_fuse_pass") # not supported for table config.switch_use_feed_fetch_ops(False) @@ -334,12 +338,11 @@ def draw_e2e_res(dt_boxes, strs, img_path): return src_im -def draw_text_det_res(dt_boxes, img_path): - src_im = cv2.imread(img_path) +def draw_text_det_res(dt_boxes, img): for box in dt_boxes: box = np.array(box).astype(np.int32).reshape(-1, 2) - cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2) - return src_im + cv2.polylines(img, [box], True, color=(255, 255, 0), thickness=2) + return img def resize_img(img, input_size=600): @@ -397,56 +400,81 @@ def draw_ocr(image, def draw_ocr_box_txt(image, boxes, - txts, + txts=None, scores=None, drop_score=0.5, - font_path="./doc/simfang.ttf"): + font_path="./doc/fonts/simfang.ttf"): h, w = image.height, image.width img_left = image.copy() - img_right = Image.new('RGB', (w, h), (255, 255, 255)) - - import random - + img_right = np.ones((h, w, 3), dtype=np.uint8) * 255 random.seed(0) + draw_left = ImageDraw.Draw(img_left) - draw_right = ImageDraw.Draw(img_right) + if txts is None or len(txts) != len(boxes): + txts = [None] * len(boxes) for idx, (box, txt) in enumerate(zip(boxes, txts)): if scores is not None and scores[idx] < drop_score: continue color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)) draw_left.polygon(box, fill=color) - draw_right.polygon( - [ - box[0][0], box[0][1], box[1][0], box[1][1], box[2][0], - box[2][1], box[3][0], box[3][1] - ], - outline=color) 
- box_height = math.sqrt((box[0][0] - box[3][0])**2 + (box[0][1] - box[3][ - 1])**2) - box_width = math.sqrt((box[0][0] - box[1][0])**2 + (box[0][1] - box[1][ - 1])**2) - if box_height > 2 * box_width: - font_size = max(int(box_width * 0.9), 10) - font = ImageFont.truetype(font_path, font_size, encoding="utf-8") - cur_y = box[0][1] - for c in txt: - char_size = font.getsize(c) - draw_right.text( - (box[0][0] + 3, cur_y), c, fill=(0, 0, 0), font=font) - cur_y += char_size[1] - else: - font_size = max(int(box_height * 0.8), 10) - font = ImageFont.truetype(font_path, font_size, encoding="utf-8") - draw_right.text( - [box[0][0], box[0][1]], txt, fill=(0, 0, 0), font=font) + img_right_text = draw_box_txt_fine((w, h), box, txt, font_path) + pts = np.array(box, np.int32).reshape((-1, 1, 2)) + cv2.polylines(img_right_text, [pts], True, color, 1) + img_right = cv2.bitwise_and(img_right, img_right_text) img_left = Image.blend(image, img_left, 0.5) img_show = Image.new('RGB', (w * 2, h), (255, 255, 255)) img_show.paste(img_left, (0, 0, w, h)) - img_show.paste(img_right, (w, 0, w * 2, h)) + img_show.paste(Image.fromarray(img_right), (w, 0, w * 2, h)) return np.array(img_show) +def draw_box_txt_fine(img_size, box, txt, font_path="./doc/fonts/simfang.ttf"): + box_height = int( + math.sqrt((box[0][0] - box[3][0])**2 + (box[0][1] - box[3][1])**2)) + box_width = int( + math.sqrt((box[0][0] - box[1][0])**2 + (box[0][1] - box[1][1])**2)) + + if box_height > 2 * box_width and box_height > 30: + img_text = Image.new('RGB', (box_height, box_width), (255, 255, 255)) + draw_text = ImageDraw.Draw(img_text) + if txt: + font = create_font(txt, (box_height, box_width), font_path) + draw_text.text([0, 0], txt, fill=(0, 0, 0), font=font) + img_text = img_text.transpose(Image.ROTATE_270) + else: + img_text = Image.new('RGB', (box_width, box_height), (255, 255, 255)) + draw_text = ImageDraw.Draw(img_text) + if txt: + font = create_font(txt, (box_width, box_height), font_path) + draw_text.text([0, 0], txt, fill=(0, 0, 0), font=font) + + pts1 = np.float32( + [[0, 0], [box_width, 0], [box_width, box_height], [0, box_height]]) + pts2 = np.array(box, dtype=np.float32) + M = cv2.getPerspectiveTransform(pts1, pts2) + + img_text = np.array(img_text, dtype=np.uint8) + img_right_text = cv2.warpPerspective( + img_text, + M, + img_size, + flags=cv2.INTER_NEAREST, + borderMode=cv2.BORDER_CONSTANT, + borderValue=(255, 255, 255)) + return img_right_text + + +def create_font(txt, sz, font_path="./doc/fonts/simfang.ttf"): + font_size = int(sz[1] * 0.99) + font = ImageFont.truetype(font_path, font_size, encoding="utf-8") + length = font.getsize(txt)[0] + if length > sz[0]: + font_size = int(font_size * sz[0] / length) + font = ImageFont.truetype(font_path, font_size, encoding="utf-8") + return font + + def str_count(s): """ Count the number of Chinese characters, diff --git a/tools/infer_e2e.py b/tools/infer_e2e.py index d3e6b28fca0a3ff32ea940747712d6c71aa290fd..37fdcbaadc2984c9cf4fb105b7122db31b99be30 100755 --- a/tools/infer_e2e.py +++ b/tools/infer_e2e.py @@ -37,6 +37,46 @@ from ppocr.postprocess import build_post_process from ppocr.utils.save_load import load_model from ppocr.utils.utility import get_image_file_list import tools.program as program +from PIL import Image, ImageDraw, ImageFont +import math + + +def draw_e2e_res_for_chinese(image, + boxes, + txts, + config, + img_name, + font_path="./doc/simfang.ttf"): + h, w = image.height, image.width + img_left = image.copy() + img_right = Image.new('RGB', (w, h), (255, 255, 255)) + + 
import random + + random.seed(0) + draw_left = ImageDraw.Draw(img_left) + draw_right = ImageDraw.Draw(img_right) + for idx, (box, txt) in enumerate(zip(boxes, txts)): + box = np.array(box) + box = [tuple(x) for x in box] + color = (random.randint(0, 255), random.randint(0, 255), + random.randint(0, 255)) + draw_left.polygon(box, fill=color) + draw_right.polygon(box, outline=color) + font = ImageFont.truetype(font_path, 15, encoding="utf-8") + draw_right.text([box[0][0], box[0][1]], txt, fill=(0, 0, 0), font=font) + img_left = Image.blend(image, img_left, 0.5) + img_show = Image.new('RGB', (w * 2, h), (255, 255, 255)) + img_show.paste(img_left, (0, 0, w, h)) + img_show.paste(img_right, (w, 0, w * 2, h)) + + save_e2e_path = os.path.dirname(config['Global'][ + 'save_res_path']) + "/e2e_results/" + if not os.path.exists(save_e2e_path): + os.makedirs(save_e2e_path) + save_path = os.path.join(save_e2e_path, os.path.basename(img_name)) + cv2.imwrite(save_path, np.array(img_show)[:, :, ::-1]) + logger.info("The e2e Image saved in {}".format(save_path)) def draw_e2e_res(dt_boxes, strs, config, img, img_name): @@ -113,7 +153,19 @@ def main(): otstr = file + "\t" + json.dumps(dt_boxes_json) + "\n" fout.write(otstr.encode()) src_img = cv2.imread(file) - draw_e2e_res(points, strs, config, src_img, file) + if global_config['infer_visual_type'] == 'EN': + draw_e2e_res(points, strs, config, src_img, file) + elif global_config['infer_visual_type'] == 'CN': + src_img = Image.fromarray( + cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB)) + draw_e2e_res_for_chinese( + src_img, + points, + strs, + config, + file, + font_path="./doc/fonts/simfang.ttf") + logger.info("success!") diff --git a/tools/infer_kie_token_ser_re.py b/tools/infer_kie_token_ser_re.py index 3ee696f28470a16205be628b3aeb586ef7a9c6a6..c4fa2c927ab93cfa9082e51f08f8d6e1c35fe29e 100755 --- a/tools/infer_kie_token_ser_re.py +++ b/tools/infer_kie_token_ser_re.py @@ -63,7 +63,7 @@ class ReArgsParser(ArgsParser): def make_input(ser_inputs, ser_results): entities_labels = {'HEADER': 0, 'QUESTION': 1, 'ANSWER': 2} - + batch_size, max_seq_len = ser_inputs[0].shape[:2] entities = ser_inputs[8][0] ser_results = ser_results[0] assert len(entities) == len(ser_results) @@ -80,34 +80,44 @@ def make_input(ser_inputs, ser_results): start.append(entity['start']) end.append(entity['end']) label.append(entities_labels[res['pred']]) - entities = dict(start=start, end=end, label=label) + + entities = np.full([max_seq_len + 1, 3], fill_value=-1) + entities[0, 0] = len(start) + entities[1:len(start) + 1, 0] = start + entities[0, 1] = len(end) + entities[1:len(end) + 1, 1] = end + entities[0, 2] = len(label) + entities[1:len(label) + 1, 2] = label # relations head = [] tail = [] - for i in range(len(entities["label"])): - for j in range(len(entities["label"])): - if entities["label"][i] == 1 and entities["label"][j] == 2: + for i in range(len(label)): + for j in range(len(label)): + if label[i] == 1 and label[j] == 2: head.append(i) tail.append(j) - relations = dict(head=head, tail=tail) + relations = np.full([len(head) + 1, 2], fill_value=-1) + relations[0, 0] = len(head) + relations[1:len(head) + 1, 0] = head + relations[0, 1] = len(tail) + relations[1:len(tail) + 1, 1] = tail + + entities = np.expand_dims(entities, axis=0) + entities = np.repeat(entities, batch_size, axis=0) + relations = np.expand_dims(relations, axis=0) + relations = np.repeat(relations, batch_size, axis=0) + + # remove ocr_info segment_offset_id and label in ser input + if isinstance(ser_inputs[0], 
paddle.Tensor): + entities = paddle.to_tensor(entities) + relations = paddle.to_tensor(relations) + ser_inputs = ser_inputs[:5] + [entities, relations] - batch_size = ser_inputs[0].shape[0] - entities_batch = [] - relations_batch = [] entity_idx_dict_batch = [] for b in range(batch_size): - entities_batch.append(entities) - relations_batch.append(relations) entity_idx_dict_batch.append(entity_idx_dict) - - ser_inputs[8] = entities_batch - ser_inputs.append(relations_batch) - # remove ocr_info segment_offset_id and label in ser input - ser_inputs.pop(7) - ser_inputs.pop(6) - ser_inputs.pop(5) return ser_inputs, entity_idx_dict_batch @@ -136,6 +146,8 @@ class SerRePredictor(object): def __call__(self, data): ser_results, ser_inputs = self.ser_engine(data) re_input, entity_idx_dict_batch = make_input(ser_inputs, ser_results) + if self.model.backbone.use_visual_backbone is False: + re_input.pop(4) preds = self.model(re_input) post_result = self.post_process_class( preds, diff --git a/tools/program.py b/tools/program.py index 16d3d4035af933cda01b422ea56e9e2895ec2b88..9117d51b95b343c46982f212d4e5faa069b7b44a 100755 --- a/tools/program.py +++ b/tools/program.py @@ -114,7 +114,7 @@ def merge_config(config, opts): return config -def check_device(use_gpu, use_xpu=False): +def check_device(use_gpu, use_xpu=False, use_npu=False): """ Log error and exit when set use_gpu=true in paddlepaddle cpu version. @@ -134,24 +134,8 @@ def check_device(use_gpu, use_xpu=False): if use_xpu and not paddle.device.is_compiled_with_xpu(): print(err.format("use_xpu", "xpu", "xpu", "use_xpu")) sys.exit(1) - except Exception as e: - pass - - -def check_xpu(use_xpu): - """ - Log error and exit when set use_xpu=true in paddlepaddle - cpu/gpu version. - """ - err = "Config use_xpu cannot be set as true while you are " \ - "using paddlepaddle cpu/gpu version ! \nPlease try: \n" \ - "\t1. Install paddlepaddle-xpu to run model on XPU \n" \ - "\t2. 
Set use_xpu as false in config file to run " \ - "model on CPU/GPU" - - try: - if use_xpu and not paddle.is_compiled_with_xpu(): - print(err) + if use_npu and not paddle.device.is_compiled_with_npu(): + print(err.format("use_npu", "npu", "npu", "use_npu")) sys.exit(1) except Exception as e: pass @@ -279,7 +263,9 @@ def train(config, model_average = True # use amp if scaler: - with paddle.amp.auto_cast(level=amp_level, custom_black_list=amp_custom_black_list): + with paddle.amp.auto_cast( + level=amp_level, + custom_black_list=amp_custom_black_list): if model_type == 'table' or extra_input: preds = model(images, data=batch[1:]) elif model_type in ["kie"]: @@ -479,7 +465,7 @@ def eval(model, extra_input=False, scaler=None, amp_level='O2', - amp_custom_black_list = []): + amp_custom_black_list=[]): model.eval() with paddle.no_grad(): total_frame = 0.0 @@ -500,7 +486,9 @@ def eval(model, # use amp if scaler: - with paddle.amp.auto_cast(level=amp_level, custom_black_list=amp_custom_black_list): + with paddle.amp.auto_cast( + level=amp_level, + custom_black_list=amp_custom_black_list): if model_type == 'table' or extra_input: preds = model(images, data=batch[1:]) elif model_type in ["kie"]: @@ -627,14 +615,9 @@ def preprocess(is_train=False): logger = get_logger(log_file=log_file) # check if set use_gpu=True in paddlepaddle cpu version - use_gpu = config['Global']['use_gpu'] + use_gpu = config['Global'].get('use_gpu', False) use_xpu = config['Global'].get('use_xpu', False) - - # check if set use_xpu=True in paddlepaddle cpu/gpu version - use_xpu = False - if 'use_xpu' in config['Global']: - use_xpu = config['Global']['use_xpu'] - check_xpu(use_xpu) + use_npu = config['Global'].get('use_npu', False) alg = config['Architecture']['algorithm'] assert alg in [ @@ -642,15 +625,17 @@ def preprocess(is_train=False): 'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE', 'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE', 'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN', 'VisionLAN', - 'Gestalt', 'SLANet', 'RobustScanner' + 'Gestalt', 'SLANet', 'RobustScanner', 'CT' ] if use_xpu: device = 'xpu:{0}'.format(os.getenv('FLAGS_selected_xpus', 0)) + elif use_npu: + device = 'npu:{0}'.format(os.getenv('FLAGS_selected_npus', 0)) else: device = 'gpu:{}'.format(dist.ParallelEnv() .dev_id) if use_gpu else 'cpu' - check_device(use_gpu, use_xpu) + check_device(use_gpu, use_xpu, use_npu) device = paddle.set_device(device) diff --git a/tools/train.py b/tools/train.py index d0f200189e34265b3c080ac9e25eb80d29c705b7..970a52624af7b2831d88956f857cd4271086bcca 100755 --- a/tools/train.py +++ b/tools/train.py @@ -119,6 +119,7 @@ def main(config, device, logger, vdl_writer): config['Loss']['ignore_index'] = char_num - 1 model = build_model(config['Architecture']) + use_sync_bn = config["Global"].get("use_sync_bn", False) if use_sync_bn: model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(model) @@ -138,7 +139,7 @@ def main(config, device, logger, vdl_writer): # build metric eval_class = build_metric(config['Metric']) - + logger.info('train dataloader has {} iters'.format(len(train_dataloader))) if valid_dataloader is not None: logger.info('valid dataloader has {} iters'.format( @@ -146,7 +147,7 @@ def main(config, device, logger, vdl_writer): use_amp = config["Global"].get("use_amp", False) amp_level = config["Global"].get("amp_level", 'O2') - amp_custom_black_list = config['Global'].get('amp_custom_black_list',[]) + amp_custom_black_list = 
config['Global'].get('amp_custom_black_list', []) if use_amp: AMP_RELATED_FLAGS_SETTING = { 'FLAGS_cudnn_batchnorm_spatial_persistent': 1, @@ -161,20 +162,24 @@ def main(config, device, logger, vdl_writer): use_dynamic_loss_scaling=use_dynamic_loss_scaling) if amp_level == "O2": model, optimizer = paddle.amp.decorate( - models=model, optimizers=optimizer, level=amp_level, master_weight=True) + models=model, + optimizers=optimizer, + level=amp_level, + master_weight=True) else: scaler = None # load pretrain model pre_best_model_dict = load_model(config, model, optimizer, config['Architecture']["model_type"]) - + if config['Global']['distributed']: model = paddle.DataParallel(model) # start train program.train(config, train_dataloader, valid_dataloader, device, model, loss_class, optimizer, lr_scheduler, post_process_class, - eval_class, pre_best_model_dict, logger, vdl_writer, scaler,amp_level, amp_custom_black_list) + eval_class, pre_best_model_dict, logger, vdl_writer, scaler, + amp_level, amp_custom_black_list) def test_reader(config, device, logger): diff --git a/train.sh b/train.sh index 4225470cb9f545b874e5f806af22405895e8f6c7..6fa04ea3febe8982016a35d83f119c0a483e3bb8 100644 --- a/train.sh +++ b/train.sh @@ -1,2 +1,2 @@ # recommended paddle.__version__ == 2.0.0 -python3 -m paddle.distributed.launch --log_dir=./debug/ --gpus '0,1,2,3,4,5,6,7' tools/train.py -c configs/rec/rec_mv3_none_bilstm_ctc.yml +python3 -m paddle.distributed.launch --log_dir=./debug/ --gpus '0,1,2,3,4,5,6,7' tools/train.py -c configs/rec/rec_mv3_none_bilstm_ctc.yml \ No newline at end of file
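
Note on the new TIPC device wrappers added above: test_tipc/test_train_inference_python_npu.sh and test_tipc/test_train_inference_python_xpu.sh take the same two arguments as test_tipc/test_train_inference_python.sh (a TIPC config txt and a run mode). They rewrite use_gpu to use_npu / use_xpu both in that txt and in the training YAML it references, disable MKL-DNN on non-x86_64 hosts, turn off --benchmark (AutoLog requires nvidia-smi), and then delegate to the original script. A minimal invocation sketch follows; the config path and mode are illustrative placeholders, not values taken from this diff.

# Run the full TIPC train/inference chain on an NPU machine (sketch).
# Replace the config txt and mode with the ones prepared for your model
# (e.g. by test_tipc/prepare.sh); the wrapper patches the configs in place
# before calling test_tipc/test_train_inference_python.sh.
bash test_tipc/test_train_inference_python_npu.sh \
    test_tipc/configs/ch_PP-OCRv3_det/train_infer_python.txt \
    lite_train_lite_infer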