diff --git a/ppocr/metrics/table_metric.py b/ppocr/metrics/table_metric.py
index 43dc1d7617904bedcdec445d86933a3b8c5850f1..c0b247efa672caacb9a9a09a8ef0da58e47367e4 100644
--- a/ppocr/metrics/table_metric.py
+++ b/ppocr/metrics/table_metric.py
@@ -16,9 +16,14 @@ from ppocr.metrics.det_metric import DetMetric
class TableStructureMetric(object):
- def __init__(self, main_indicator='acc', eps=1e-6, **kwargs):
+ def __init__(self,
+ main_indicator='acc',
+ eps=1e-6,
+ del_thead_tbody=False,
+ **kwargs):
self.main_indicator = main_indicator
self.eps = eps
+ self.del_thead_tbody = del_thead_tbody
self.reset()
def __call__(self, pred_label, batch=None, *args, **kwargs):
@@ -31,6 +36,13 @@ class TableStructureMetric(object):
gt_structure_batch_list):
pred_str = ''.join(pred)
target_str = ''.join(target)
+ if self.del_thead_tbody:
+ pred_str = pred_str.replace('', '').replace(
+ '', '').replace('
', '').replace('',
+ '')
+ target_str = target_str.replace('', '').replace(
+ '', '').replace('', '').replace('',
+ '')
if pred_str == target_str:
correct_num += 1
all_num += 1
@@ -60,6 +72,7 @@ class TableMetric(object):
main_indicator='acc',
compute_bbox_metric=False,
box_format='xyxy',
+ del_thead_tbody=False,
**kwargs):
"""
@@ -67,7 +80,8 @@ class TableMetric(object):
@param main_matric: main_matric for save best_model
@param kwargs:
"""
- self.structure_metric = TableStructureMetric()
+ self.structure_metric = TableStructureMetric(
+ del_thead_tbody=del_thead_tbody)
self.bbox_metric = DetMetric() if compute_bbox_metric else None
self.main_indicator = main_indicator
self.box_format = box_format
diff --git a/ppstructure/docs/quickstart.md b/ppstructure/docs/quickstart.md
index d206d1d5217b3711d3865625d2bb4d2c7b13b423..700800759923ec4f3703bdb542bee0df930d6cd1 100644
--- a/ppstructure/docs/quickstart.md
+++ b/ppstructure/docs/quickstart.md
@@ -3,15 +3,17 @@
- [1. 安装依赖包](#1-安装依赖包)
- [2. 便捷使用](#2-便捷使用)
- [2.1 命令行使用](#21-命令行使用)
+ - [2.1.1 图像方向分类+版面分析+表格识别](#211-图像方向分类版面分析表格识别)
- [2.1.1 版面分析+表格识别](#211-版面分析表格识别)
- - [2.1.2 版面分析](#212-版面分析)
- - [2.1.3 表格识别](#213-表格识别)
- - [2.1.4 DocVQA](#214-docvqa)
+ - [2.1.3 版面分析](#213-版面分析)
+ - [2.1.4 表格识别](#214-表格识别)
+ - [2.1.5 DocVQA](#215-docvqa)
- [2.2 代码使用](#22-代码使用)
- - [2.2.1 版面分析+表格识别](#221-版面分析表格识别)
- - [2.2.2 版面分析](#222-版面分析)
- - [2.2.3 表格识别](#223-表格识别)
- - [2.2.4 DocVQA](#224-docvqa)
+ - [2.2.1 图像方向分类版面分析表格识别](#221-图像方向分类版面分析表格识别)
+ - [2.2.2 版面分析+表格识别](#222-版面分析表格识别)
+ - [2.2.3 版面分析](#223-版面分析)
+ - [2.2.4 表格识别](#224-表格识别)
+ - [2.2.5 DocVQA](#225-docvqa)
- [2.3 返回结果说明](#23-返回结果说明)
- [2.3.1 版面分析+表格识别](#231-版面分析表格识别)
- [2.3.2 DocVQA](#232-docvqa)
@@ -36,25 +38,31 @@ pip install paddlenlp
### 2.1 命令行使用
+#### 2.1.1 图像方向分类+版面分析+表格识别
+```bash
+paddleocr --image_dir=PaddleOCR/ppstructure/docs/table/1.png --type=structure --image_orientation=true
+```
+
+
#### 2.1.1 版面分析+表格识别
```bash
paddleocr --image_dir=PaddleOCR/ppstructure/docs/table/1.png --type=structure
```
-
-#### 2.1.2 版面分析
+
+#### 2.1.3 版面分析
```bash
paddleocr --image_dir=PaddleOCR/ppstructure/docs/table/1.png --type=structure --table=false --ocr=false
```
-
-#### 2.1.3 表格识别
+
+#### 2.1.4 表格识别
```bash
paddleocr --image_dir=PaddleOCR/ppstructure/docs/table/table.jpg --type=structure --layout=false
```
-
-#### 2.1.4 DocVQA
+
+#### 2.1.5 DocVQA
请参考:[文档视觉问答](../vqa/README.md)。
@@ -62,14 +70,14 @@ paddleocr --image_dir=PaddleOCR/ppstructure/docs/table/table.jpg --type=structur
### 2.2 代码使用
-#### 2.2.1 版面分析+表格识别
+#### 2.2.1 图像方向分类版面分析表格识别
```python
import os
import cv2
from paddleocr import PPStructure,draw_structure_result,save_structure_res
-table_engine = PPStructure(show_log=True)
+table_engine = PPStructure(show_log=True, image_orientation=True)
save_folder = './output'
img_path = 'PaddleOCR/ppstructure/docs/table/1.png'
@@ -91,7 +99,36 @@ im_show.save('result.jpg')
```
-#### 2.2.2 版面分析
+#### 2.2.2 版面分析+表格识别
+
+```python
+import os
+import cv2
+from paddleocr import PPStructure,draw_structure_result,save_structure_res
+
+table_engine = PPStructure(show_log=True)
+
+save_folder = './output'
+img_path = 'PaddleOCR/ppstructure/docs/table/1.png'
+img = cv2.imread(img_path)
+result = table_engine(img)
+save_structure_res(result, save_folder,os.path.basename(img_path).split('.')[0])
+
+for line in result:
+ line.pop('img')
+ print(line)
+
+from PIL import Image
+
+font_path = 'PaddleOCR/doc/fonts/simfang.ttf' # PaddleOCR下提供字体包
+image = Image.open(img_path).convert('RGB')
+im_show = draw_structure_result(image, result,font_path=font_path)
+im_show = Image.fromarray(im_show)
+im_show.save('result.jpg')
+```
+
+
+#### 2.2.3 版面分析
```python
import os
@@ -111,8 +148,8 @@ for line in result:
print(line)
```
-
-#### 2.2.3 表格识别
+
+#### 2.2.4 表格识别
```python
import os
@@ -132,8 +169,8 @@ for line in result:
print(line)
```
-
-#### 2.2.4 DocVQA
+
+#### 2.2.5 DocVQA
请参考:[文档视觉问答](../vqa/README.md)。
@@ -154,10 +191,10 @@ PP-Structure的返回结果为一个dict组成的list,示例如下
```
dict 里各个字段说明如下
-| 字段 | 说明 |
-| --------------- |-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
-|type| 图片区域的类型 |
-|bbox| 图片区域的在原图的坐标,分别[左上角x,左上角y,右下角x,右下角y] |
+| 字段 | 说明|
+| --- |---|
+|type| 图片区域的类型 |
+|bbox| 图片区域的在原图的坐标,分别[左上角x,左上角y,右下角x,右下角y]|
|res| 图片区域的OCR或表格识别结果。
表格: 一个dict,字段说明如下
`html`: 表格的HTML字符串
在代码使用模式下,前向传入return_ocr_result_in_table=True可以拿到表格中每个文本的检测识别结果,对应为如下字段:
`boxes`: 文本检测坐标
`rec_res`: 文本识别结果。
OCR: 一个包含各个单行文字的检测坐标和识别结果的元组 |
运行完成后,每张图片会在`output`字段指定的目录下有一个同名目录,图片里的每个表格会存储为一个excel,图片区域会被裁剪之后保存下来,excel文件和图片名为表格在图片里的坐标。
@@ -178,20 +215,26 @@ dict 里各个字段说明如下
### 2.4 参数说明
-| 字段 | 说明 | 默认值 |
-|----------------------|----------------------------------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------|
-| output | excel和识别结果保存的地址 | ./output/table |
-| table_max_len | 表格结构模型预测时,图像的长边resize尺度 | 488 |
-| table_model_dir | 表格结构模型 inference 模型地址 | None |
-| table_char_dict_path | 表格结构模型所用字典地址 | ../ppocr/utils/dict/table_structure_dict.txt |
-| layout_path_model | 版面分析模型模型地址,可以为在线地址或者本地地址,当为本地地址时,需要指定 layout_label_map, 命令行模式下可通过--layout_label_map='{0: "Text", 1: "Title", 2: "List", 3:"Table", 4:"Figure"}' 指定 | lp://PubLayNet/ppyolov2_r50vd_dcn_365e_publaynet/config |
-| layout_label_map | 版面分析模型模型label映射字典 | None |
-| model_name_or_path | VQA SER模型地址 | None |
-| max_seq_length | VQA SER模型最大支持token长度 | 512 |
-| label_map_path | VQA SER 标签文件地址 | ./vqa/labels/labels_ser.txt |
-| layout | 前向中是否执行版面分析 | True |
-| table | 前向中是否执行表格识别 | True |
-| ocr | 对于版面分析中的非表格区域,是否执行ocr。当layout为False时会被自动设置为False | True |
-| structure_version | 表格结构化模型版本,可选 PP-STRUCTURE。PP-STRUCTURE支持表格结构化模型 | PP-STRUCTURE |
+| 字段 | 说明 | 默认值 |
+|---|---|---|
+| output | 结果保存地址 | ./output/table |
+| table_max_len | 表格结构模型预测时,图像的长边resize尺度 | 488 |
+| table_model_dir | 表格结构模型 inference 模型地址| None |
+| table_char_dict_path | 表格结构模型所用字典地址 | ../ppocr/utils/dict/table_structure_dict.txt |
+| merge_no_span_structure | 表格识别模型中,是否对'\'和'\ | ' 进行合并 | False |
+| layout_model_dir | 版面分析模型 inference 模型地址 | None |
+| layout_dict_path | 版面分析模型字典| ../ppocr/utils/dict/layout_publaynet_dict.txt |
+| layout_score_threshold | 版面分析模型检测框阈值| 0.5|
+| layout_nms_threshold | 版面分析模型nms阈值| 0.5|
+| vqa_algorithm | vqa模型算法| LayoutXLM|
+| ser_model_dir | ser模型 inference 模型地址| None|
+| ser_dict_path | ser模型字典| ../train_data/XFUND/class_list_xfun.txt|
+| mode | structure or vqa | structure |
+| image_orientation | 前向中是否执行图像方向分类 | False |
+| layout | 前向中是否执行版面分析 | True |
+| table | 前向中是否执行表格识别 | True |
+| ocr | 对于版面分析中的非表格区域,是否执行ocr。当layout为False时会被自动设置为False| True |
+| recovery | 前向中是否执行版面恢复| False |
+| structure_version | 模型版本,可选 PP-structure和PP-structurev2 | PP-structure |
大部分参数和PaddleOCR whl包保持一致,见 [whl包文档](../../doc/doc_ch/whl.md)
diff --git a/ppstructure/docs/quickstart_en.md b/ppstructure/docs/quickstart_en.md
index 98d8d2fc3fa5e7c8cd1670f62da8093d99f2d8fa..b4dee3f02d3c2762ef71720995f4da697ae43622 100644
--- a/ppstructure/docs/quickstart_en.md
+++ b/ppstructure/docs/quickstart_en.md
@@ -3,15 +3,17 @@
- [1. Install package](#1-install-package)
- [2. Use](#2-use)
- [2.1 Use by command line](#21-use-by-command-line)
- - [2.1.1 layout analysis + table recognition](#211-layout-analysis--table-recognition)
- - [2.1.2 layout analysis](#212-layout-analysis)
- - [2.1.3 table recognition](#213-table-recognition)
- - [2.1.4 DocVQA](#214-docvqa)
+ - [2.1.1 image orientation + layout analysis + table recognition](#211-image-orientation--layout-analysis--table-recognition)
+ - [2.1.2 layout analysis + table recognition](#212-layout-analysis--table-recognition)
+ - [2.1.3 layout analysis](#213-layout-analysis)
+ - [2.1.4 table recognition](#214-table-recognition)
+ - [2.1.5 DocVQA](#215-docvqa)
- [2.2 Use by code](#22-use-by-code)
- - [2.2.1 layout analysis + table recognition](#221-layout-analysis--table-recognition)
- - [2.2.2 layout analysis](#222-layout-analysis)
- - [2.2.3 table recognition](#223-table-recognition)
- - [2.2.4 DocVQA](#224-docvqa)
+ - [2.2.1 image orientation + layout analysis + table recognition](#221-image-orientation--layout-analysis--table-recognition)
+ - [2.2.2 layout analysis + table recognition](#222-layout-analysis--table-recognition)
+ - [2.2.3 layout analysis](#223-layout-analysis)
+ - [2.2.4 table recognition](#224-table-recognition)
+ - [2.2.5 DocVQA](#225-docvqa)
- [2.3 Result description](#23-result-description)
- [2.3.1 layout analysis + table recognition](#231-layout-analysis--table-recognition)
- [2.3.2 DocVQA](#232-docvqa)
@@ -36,25 +38,31 @@ pip install paddlenlp
### 2.1 Use by command line
-#### 2.1.1 layout analysis + table recognition
+#### 2.1.1 image orientation + layout analysis + table recognition
```bash
-paddleocr --image_dir=PaddleOCR/ppstructure/docs/table/1.png --type=structure
+paddleocr --image_dir=PaddleOCR/ppstructure/docs/table/1.png --type=structure --image_orientation=true
```
-#### 2.1.2 layout analysis
+#### 2.1.2 layout analysis + table recognition
```bash
-paddleocr --image_dir=PaddleOCR/ppstructure/docs/table/1.png --type=structure --table=false --ocr=false
+paddleocr --image_dir=PaddleOCR/ppstructure/docs/table/1.png --type=structure
```
-#### 2.1.3 table recognition
+#### 2.1.3 layout analysis
```bash
-paddleocr --image_dir=PaddleOCR/ppstructure/docs/table/table.jpg --type=structure --layout=false
+paddleocr --image_dir=PaddleOCR/ppstructure/docs/table/1.png --type=structure --table=false --ocr=false
```
-#### 2.1.4 DocVQA
+#### 2.1.4 table recognition
+```bash
+paddleocr --image_dir=PaddleOCR/ppstructure/docs/table/table.jpg --type=structure --layout=false
+```
+
+
+#### 2.1.5 DocVQA
Please refer to: [Documentation Visual Q&A](../vqa/README.md) .
@@ -62,14 +70,14 @@ Please refer to: [Documentation Visual Q&A](../vqa/README.md) .
### 2.2 Use by code
-#### 2.2.1 layout analysis + table recognition
+#### 2.2.1 image orientation + layout analysis + table recognition
```python
import os
import cv2
from paddleocr import PPStructure,draw_structure_result,save_structure_res
-table_engine = PPStructure(show_log=True)
+table_engine = PPStructure(show_log=True, image_orientation=True)
save_folder = './output'
img_path = 'PaddleOCR/ppstructure/docs/table/1.png'
@@ -91,7 +99,36 @@ im_show.save('result.jpg')
```
-#### 2.2.2 layout analysis
+#### 2.2.2 layout analysis + table recognition
+
+```python
+import os
+import cv2
+from paddleocr import PPStructure,draw_structure_result,save_structure_res
+
+table_engine = PPStructure(show_log=True)
+
+save_folder = './output'
+img_path = 'PaddleOCR/ppstructure/docs/table/1.png'
+img = cv2.imread(img_path)
+result = table_engine(img)
+save_structure_res(result, save_folder,os.path.basename(img_path).split('.')[0])
+
+for line in result:
+ line.pop('img')
+ print(line)
+
+from PIL import Image
+
+font_path = 'PaddleOCR/doc/fonts/simfang.ttf' # PaddleOCR下提供字体包
+image = Image.open(img_path).convert('RGB')
+im_show = draw_structure_result(image, result,font_path=font_path)
+im_show = Image.fromarray(im_show)
+im_show.save('result.jpg')
+```
+
+
+#### 2.2.3 layout analysis
```python
import os
@@ -111,8 +148,8 @@ for line in result:
print(line)
```
-
-#### 2.2.3 table recognition
+
+#### 2.2.4 table recognition
```python
import os
@@ -132,8 +169,8 @@ for line in result:
print(line)
```
-
-#### 2.2.4 DocVQA
+
+#### 2.2.5 DocVQA
Please refer to: [Documentation Visual Q&A](../vqa/README.md) .
@@ -155,8 +192,8 @@ The return of PP-Structure is a list of dicts, the example is as follows:
```
Each field in dict is described as follows:
-| field | description |
-| --------------- |--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
+| field | description |
+| --- |---|
|type| Type of image area. |
|bbox| The coordinates of the image area in the original image, respectively [upper left corner x, upper left corner y, lower right corner x, lower right corner y]. |
|res| OCR or table recognition result of the image area.
table: a dict with field descriptions as follows:
`html`: html str of table.
In the code usage mode, set return_ocr_result_in_table=True whrn call can get the detection and recognition results of each text in the table area, corresponding to the following fields:
`boxes`: text detection boxes.
`rec_res`: text recognition results.
OCR: A tuple containing the detection boxes and recognition results of each single text. |
@@ -178,19 +215,26 @@ Please refer to: [Documentation Visual Q&A](../vqa/README.md) .
### 2.4 Parameter Description
-| field | description | default |
-|----------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------|
-| output | The save path of result | ./output/table |
-| table_max_len | When the table structure model predicts, the long side of the image | 488 |
-| table_model_dir | the path of table structure model | None |
-| table_char_dict_path | the dict path of table structure model | ../ppocr/utils/dict/table_structure_dict.txt |
-| layout_path_model | The model path of the layout analysis model, which can be an online address or a local path. When it is a local path, layout_label_map needs to be set. In command line mode, use --layout_label_map='{0: "Text", 1: "Title", 2: "List", 3:"Table", 4:"Figure"}' | lp://PubLayNet/ppyolov2_r50vd_dcn_365e_publaynet/config |
-| layout_label_map | Layout analysis model model label mapping dictionary path | None |
-| model_name_or_path | the model path of VQA SER model | None |
-| max_seq_length | the max token length of VQA SER model | 512 |
-| label_map_path | the label path of VQA SER model | ./vqa/labels/labels_ser.txt |
-| layout | Whether to perform layout analysis in forward | True |
-| table | Whether to perform table recognition in forward | True |
-| ocr | Whether to perform ocr for non-table areas in layout analysis. When layout is False, it will be automatically set to False | True |
-| structure_version | table structure Model version number, the current model support list is as follows: PP-STRUCTURE support english table structure model | PP-STRUCTURE |
+| field | description | default |
+|---|---|---|
+| output | result save path | ./output/table |
+| table_max_len | long side of the image resize in table structure model | 488 |
+| table_model_dir | Table structure model inference model path| None |
+| table_char_dict_path | The dictionary path of table structure model | ../ppocr/utils/dict/table_structure_dict.txt |
+| merge_no_span_structure | In the table recognition model, whether to merge '\' and '\ | ' | False |
+| layout_model_dir | Layout analysis model inference model path| None |
+| layout_dict_path | The dictionary path of layout analysis model| ../ppocr/utils/dict/layout_publaynet_dict.txt |
+| layout_score_threshold | The box threshold path of layout analysis model| 0.5|
+| layout_nms_threshold | The nms threshold path of layout analysis model| 0.5|
+| vqa_algorithm | vqa model algorithm| LayoutXLM|
+| ser_model_dir | Ser model inference model path| None|
+| ser_dict_path | The dictionary path of Ser model| ../train_data/XFUND/class_list_xfun.txt|
+| mode | structure or vqa | structure |
+| image_orientation | Whether to perform image orientation classification in forward | False |
+| layout | Whether to perform layout analysis in forward | True |
+| table | Whether to perform table recognition in forward | True |
+| ocr | Whether to perform ocr for non-table areas in layout analysis. When layout is False, it will be automatically set to False| True |
+| recovery | Whether to perform layout recovery in forward| False |
+| structure_version | Structure version, optional PP-structure and PP-structurev2 | PP-structure |
+
Most of the parameters are consistent with the PaddleOCR whl package, see [whl package documentation](../../doc/doc_en/whl.md)
diff --git a/ppstructure/predict_system.py b/ppstructure/predict_system.py
index 608f4d2fb3df6d8972fd1ae1d503e15b135d53e2..053a8aac00ffe762dd05d7f8030db9aaa32c0f8a 100644
--- a/ppstructure/predict_system.py
+++ b/ppstructure/predict_system.py
@@ -27,7 +27,6 @@ import numpy as np
import time
import logging
from copy import deepcopy
-from attrdict import AttrDict
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppocr.utils.logging import get_logger
@@ -44,6 +43,13 @@ class StructureSystem(object):
def __init__(self, args):
self.mode = args.mode
self.recovery = args.recovery
+
+ self.image_orientation_predictor = None
+ if args.image_orientation:
+ import paddleclas
+ self.image_orientation_predictor = paddleclas.PaddleClas(
+ model_name="text_image_orientation")
+
if self.mode == 'structure':
if not args.show_log:
logger.setLevel(logging.INFO)
@@ -74,6 +80,7 @@ class StructureSystem(object):
def __call__(self, img, return_ocr_result_in_table=False):
time_dict = {
+ 'image_orientation': 0,
'layout': 0,
'table': 0,
'table_match': 0,
@@ -83,6 +90,20 @@ class StructureSystem(object):
'all': 0
}
start = time.time()
+ if self.image_orientation_predictor is not None:
+ tic = time.time()
+ cls_result = self.image_orientation_predictor.predict(
+ input_data=img)
+ cls_res = next(cls_result)
+ angle = cls_res[0]['label_names'][0]
+ cv_rotate_code = {
+ '90': cv2.ROTATE_90_COUNTERCLOCKWISE,
+ '180': cv2.ROTATE_180,
+ '270': cv2.ROTATE_90_CLOCKWISE
+ }
+ img = cv2.rotate(img, cv_rotate_code[angle])
+ toc = time.time()
+ time_dict['image_orientation'] = toc - tic
if self.mode == 'structure':
ori_im = img.copy()
if self.layout_predictor is not None:
@@ -121,7 +142,10 @@ class StructureSystem(object):
roi_img)
time_dict['det'] += ocr_time_dict['det']
time_dict['rec'] += ocr_time_dict['rec']
- # remove style char
+
+ # remove style char,
+ # when using the recognition model trained on the PubtabNet dataset,
+ # it will recognize the text format in the table, such as
style_token = [
'', '', '', '', '',
'', '', '', '',
@@ -198,7 +222,6 @@ def main(args):
if img is None:
logger.error("error in loading image:{}".format(image_file))
continue
- starttime = time.time()
res, time_dict = structure_sys(img)
if structure_sys.mode == 'structure':
@@ -213,8 +236,7 @@ def main(args):
logger.info('result save to {}'.format(img_save_path))
if args.recovery:
convert_info_docx(img, res, save_folder, img_name)
- elapse = time.time() - starttime
- logger.info("Predict time : {:.3f}s".format(elapse))
+ logger.info("Predict time : {:.3f}s".format(time_dict['all']))
if __name__ == "__main__":
diff --git a/ppstructure/utility.py b/ppstructure/utility.py
index 597d9785168da3718ff222bed73b68360a6b4b86..fcba52b27d34e4fe4992d084da96e7b012014463 100644
--- a/ppstructure/utility.py
+++ b/ppstructure/utility.py
@@ -62,6 +62,11 @@ def init_args():
type=str,
default='structure',
help='structure and vqa is supported')
+ parser.add_argument(
+ "--image_orientation",
+ type=bool,
+ default=False,
+ help='Whether to enable image orientation recognition')
parser.add_argument(
"--layout",
type=str2bool,