提交 73c77ff7 编写于 作者: 文幕地方's avatar 文幕地方

add image_orientation and update quickstart

上级 c2c43bb1
......@@ -16,9 +16,14 @@ from ppocr.metrics.det_metric import DetMetric
class TableStructureMetric(object):
def __init__(self, main_indicator='acc', eps=1e-6, **kwargs):
def __init__(self,
main_indicator='acc',
eps=1e-6,
del_thead_tbody=False,
**kwargs):
self.main_indicator = main_indicator
self.eps = eps
self.del_thead_tbody = del_thead_tbody
self.reset()
def __call__(self, pred_label, batch=None, *args, **kwargs):
......@@ -31,6 +36,13 @@ class TableStructureMetric(object):
gt_structure_batch_list):
pred_str = ''.join(pred)
target_str = ''.join(target)
if self.del_thead_tbody:
pred_str = pred_str.replace('<thead>', '').replace(
'</thead>', '').replace('<tbody>', '').replace('</tbody>',
'')
target_str = target_str.replace('<thead>', '').replace(
'</thead>', '').replace('<tbody>', '').replace('</tbody>',
'')
if pred_str == target_str:
correct_num += 1
all_num += 1
......@@ -60,6 +72,7 @@ class TableMetric(object):
main_indicator='acc',
compute_bbox_metric=False,
box_format='xyxy',
del_thead_tbody=False,
**kwargs):
"""
......@@ -67,7 +80,8 @@ class TableMetric(object):
@param main_matric: main_matric for save best_model
@param kwargs:
"""
self.structure_metric = TableStructureMetric()
self.structure_metric = TableStructureMetric(
del_thead_tbody=del_thead_tbody)
self.bbox_metric = DetMetric() if compute_bbox_metric else None
self.main_indicator = main_indicator
self.box_format = box_format
......
......@@ -3,15 +3,17 @@
- [1. 安装依赖包](#1-安装依赖包)
- [2. 便捷使用](#2-便捷使用)
- [2.1 命令行使用](#21-命令行使用)
- [2.1.1 图像方向分类+版面分析+表格识别](#211-图像方向分类版面分析表格识别)
- [2.1.1 版面分析+表格识别](#211-版面分析表格识别)
- [2.1.2 版面分析](#212-版面分析)
- [2.1.3 表格识别](#213-表格识别)
- [2.1.4 DocVQA](#214-docvqa)
- [2.1.3 版面分析](#213-版面分析)
- [2.1.4 表格识别](#214-表格识别)
- [2.1.5 DocVQA](#215-docvqa)
- [2.2 代码使用](#22-代码使用)
- [2.2.1 版面分析+表格识别](#221-版面分析表格识别)
- [2.2.2 版面分析](#222-版面分析)
- [2.2.3 表格识别](#223-表格识别)
- [2.2.4 DocVQA](#224-docvqa)
- [2.2.1 图像方向分类版面分析表格识别](#221-图像方向分类版面分析表格识别)
- [2.2.2 版面分析+表格识别](#222-版面分析表格识别)
- [2.2.3 版面分析](#223-版面分析)
- [2.2.4 表格识别](#224-表格识别)
- [2.2.5 DocVQA](#225-docvqa)
- [2.3 返回结果说明](#23-返回结果说明)
- [2.3.1 版面分析+表格识别](#231-版面分析表格识别)
- [2.3.2 DocVQA](#232-docvqa)
......@@ -36,25 +38,31 @@ pip install paddlenlp
### 2.1 命令行使用
<a name="211"></a>
#### 2.1.1 图像方向分类+版面分析+表格识别
```bash
paddleocr --image_dir=PaddleOCR/ppstructure/docs/table/1.png --type=structure --image_orientation=true
```
<a name="212"></a>
#### 2.1.1 版面分析+表格识别
```bash
paddleocr --image_dir=PaddleOCR/ppstructure/docs/table/1.png --type=structure
```
<a name="212"></a>
#### 2.1.2 版面分析
<a name="213"></a>
#### 2.1.3 版面分析
```bash
paddleocr --image_dir=PaddleOCR/ppstructure/docs/table/1.png --type=structure --table=false --ocr=false
```
<a name="213"></a>
#### 2.1.3 表格识别
<a name="214"></a>
#### 2.1.4 表格识别
```bash
paddleocr --image_dir=PaddleOCR/ppstructure/docs/table/table.jpg --type=structure --layout=false
```
<a name="214"></a>
#### 2.1.4 DocVQA
<a name="215"></a>
#### 2.1.5 DocVQA
请参考:[文档视觉问答](../vqa/README.md)
......@@ -62,14 +70,14 @@ paddleocr --image_dir=PaddleOCR/ppstructure/docs/table/table.jpg --type=structur
### 2.2 代码使用
<a name="221"></a>
#### 2.2.1 版面分析+表格识别
#### 2.2.1 图像方向分类版面分析表格识别
```python
import os
import cv2
from paddleocr import PPStructure,draw_structure_result,save_structure_res
table_engine = PPStructure(show_log=True)
table_engine = PPStructure(show_log=True, image_orientation=True)
save_folder = './output'
img_path = 'PaddleOCR/ppstructure/docs/table/1.png'
......@@ -91,7 +99,36 @@ im_show.save('result.jpg')
```
<a name="222"></a>
#### 2.2.2 版面分析
#### 2.2.2 版面分析+表格识别
```python
import os
import cv2
from paddleocr import PPStructure,draw_structure_result,save_structure_res
table_engine = PPStructure(show_log=True)
save_folder = './output'
img_path = 'PaddleOCR/ppstructure/docs/table/1.png'
img = cv2.imread(img_path)
result = table_engine(img)
save_structure_res(result, save_folder,os.path.basename(img_path).split('.')[0])
for line in result:
line.pop('img')
print(line)
from PIL import Image
font_path = 'PaddleOCR/doc/fonts/simfang.ttf' # PaddleOCR下提供字体包
image = Image.open(img_path).convert('RGB')
im_show = draw_structure_result(image, result,font_path=font_path)
im_show = Image.fromarray(im_show)
im_show.save('result.jpg')
```
<a name="223"></a>
#### 2.2.3 版面分析
```python
import os
......@@ -111,8 +148,8 @@ for line in result:
print(line)
```
<a name="223"></a>
#### 2.2.3 表格识别
<a name="224"></a>
#### 2.2.4 表格识别
```python
import os
......@@ -132,8 +169,8 @@ for line in result:
print(line)
```
<a name="224"></a>
#### 2.2.4 DocVQA
<a name="225"></a>
#### 2.2.5 DocVQA
请参考:[文档视觉问答](../vqa/README.md)
......@@ -154,10 +191,10 @@ PP-Structure的返回结果为一个dict组成的list,示例如下
```
dict 里各个字段说明如下
| 字段 | 说明 |
| --------------- |-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|type| 图片区域的类型 |
|bbox| 图片区域的在原图的坐标,分别[左上角x,左上角y,右下角x,右下角y] |
| 字段 | 说明|
| --- |---|
|type| 图片区域的类型 |
|bbox| 图片区域的在原图的坐标,分别[左上角x,左上角y,右下角x,右下角y]|
|res| 图片区域的OCR或表格识别结果。<br> 表格: 一个dict,字段说明如下<br>&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp; `html`: 表格的HTML字符串<br>&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp; 在代码使用模式下,前向传入return_ocr_result_in_table=True可以拿到表格中每个文本的检测识别结果,对应为如下字段: <br>&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp; `boxes`: 文本检测坐标<br>&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp; `rec_res`: 文本识别结果。<br> OCR: 一个包含各个单行文字的检测坐标和识别结果的元组 |
运行完成后,每张图片会在`output`字段指定的目录下有一个同名目录,图片里的每个表格会存储为一个excel,图片区域会被裁剪之后保存下来,excel文件和图片名为表格在图片里的坐标。
......@@ -178,20 +215,26 @@ dict 里各个字段说明如下
<a name="24"></a>
### 2.4 参数说明
| 字段 | 说明 | 默认值 |
|----------------------|----------------------------------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------|
| output | excel和识别结果保存的地址 | ./output/table |
| table_max_len | 表格结构模型预测时,图像的长边resize尺度 | 488 |
| table_model_dir | 表格结构模型 inference 模型地址 | None |
| table_char_dict_path | 表格结构模型所用字典地址 | ../ppocr/utils/dict/table_structure_dict.txt |
| layout_path_model | 版面分析模型模型地址,可以为在线地址或者本地地址,当为本地地址时,需要指定 layout_label_map, 命令行模式下可通过--layout_label_map='{0: "Text", 1: "Title", 2: "List", 3:"Table", 4:"Figure"}' 指定 | lp://PubLayNet/ppyolov2_r50vd_dcn_365e_publaynet/config |
| layout_label_map | 版面分析模型模型label映射字典 | None |
| model_name_or_path | VQA SER模型地址 | None |
| max_seq_length | VQA SER模型最大支持token长度 | 512 |
| label_map_path | VQA SER 标签文件地址 | ./vqa/labels/labels_ser.txt |
| layout | 前向中是否执行版面分析 | True |
| table | 前向中是否执行表格识别 | True |
| ocr | 对于版面分析中的非表格区域,是否执行ocr。当layout为False时会被自动设置为False | True |
| structure_version | 表格结构化模型版本,可选 PP-STRUCTURE。PP-STRUCTURE支持表格结构化模型 | PP-STRUCTURE |
| 字段 | 说明 | 默认值 |
|---|---|---|
| output | 结果保存地址 | ./output/table |
| table_max_len | 表格结构模型预测时,图像的长边resize尺度 | 488 |
| table_model_dir | 表格结构模型 inference 模型地址| None |
| table_char_dict_path | 表格结构模型所用字典地址 | ../ppocr/utils/dict/table_structure_dict.txt |
| merge_no_span_structure | 表格识别模型中,是否对'\<td>'和'\</td>' 进行合并 | False |
| layout_model_dir | 版面分析模型 inference 模型地址 | None |
| layout_dict_path | 版面分析模型字典| ../ppocr/utils/dict/layout_publaynet_dict.txt |
| layout_score_threshold | 版面分析模型检测框阈值| 0.5|
| layout_nms_threshold | 版面分析模型nms阈值| 0.5|
| vqa_algorithm | vqa模型算法| LayoutXLM|
| ser_model_dir | ser模型 inference 模型地址| None|
| ser_dict_path | ser模型字典| ../train_data/XFUND/class_list_xfun.txt|
| mode | structure or vqa | structure |
| image_orientation | 前向中是否执行图像方向分类 | False |
| layout | 前向中是否执行版面分析 | True |
| table | 前向中是否执行表格识别 | True |
| ocr | 对于版面分析中的非表格区域,是否执行ocr。当layout为False时会被自动设置为False| True |
| recovery | 前向中是否执行版面恢复| False |
| structure_version | 模型版本,可选 PP-structure和PP-structurev2 | PP-structure |
大部分参数和PaddleOCR whl包保持一致,见 [whl包文档](../../doc/doc_ch/whl.md)
......@@ -3,15 +3,17 @@
- [1. Install package](#1-install-package)
- [2. Use](#2-use)
- [2.1 Use by command line](#21-use-by-command-line)
- [2.1.1 layout analysis + table recognition](#211-layout-analysis--table-recognition)
- [2.1.2 layout analysis](#212-layout-analysis)
- [2.1.3 table recognition](#213-table-recognition)
- [2.1.4 DocVQA](#214-docvqa)
- [2.1.1 image orientation + layout analysis + table recognition](#211-image-orientation--layout-analysis--table-recognition)
- [2.1.2 layout analysis + table recognition](#212-layout-analysis--table-recognition)
- [2.1.3 layout analysis](#213-layout-analysis)
- [2.1.4 table recognition](#214-table-recognition)
- [2.1.5 DocVQA](#215-docvqa)
- [2.2 Use by code](#22-use-by-code)
- [2.2.1 layout analysis + table recognition](#221-layout-analysis--table-recognition)
- [2.2.2 layout analysis](#222-layout-analysis)
- [2.2.3 table recognition](#223-table-recognition)
- [2.2.4 DocVQA](#224-docvqa)
- [2.2.1 image orientation + layout analysis + table recognition](#221-image-orientation--layout-analysis--table-recognition)
- [2.2.2 layout analysis + table recognition](#222-layout-analysis--table-recognition)
- [2.2.3 layout analysis](#223-layout-analysis)
- [2.2.4 table recognition](#224-table-recognition)
- [2.2.5 DocVQA](#225-docvqa)
- [2.3 Result description](#23-result-description)
- [2.3.1 layout analysis + table recognition](#231-layout-analysis--table-recognition)
- [2.3.2 DocVQA](#232-docvqa)
......@@ -36,25 +38,31 @@ pip install paddlenlp
### 2.1 Use by command line
<a name="211"></a>
#### 2.1.1 layout analysis + table recognition
#### 2.1.1 image orientation + layout analysis + table recognition
```bash
paddleocr --image_dir=PaddleOCR/ppstructure/docs/table/1.png --type=structure
paddleocr --image_dir=PaddleOCR/ppstructure/docs/table/1.png --type=structure --image_orientation=true
```
<a name="212"></a>
#### 2.1.2 layout analysis
#### 2.1.2 layout analysis + table recognition
```bash
paddleocr --image_dir=PaddleOCR/ppstructure/docs/table/1.png --type=structure --table=false --ocr=false
paddleocr --image_dir=PaddleOCR/ppstructure/docs/table/1.png --type=structure
```
<a name="213"></a>
#### 2.1.3 table recognition
#### 2.1.3 layout analysis
```bash
paddleocr --image_dir=PaddleOCR/ppstructure/docs/table/table.jpg --type=structure --layout=false
paddleocr --image_dir=PaddleOCR/ppstructure/docs/table/1.png --type=structure --table=false --ocr=false
```
<a name="214"></a>
#### 2.1.4 DocVQA
#### 2.1.4 table recognition
```bash
paddleocr --image_dir=PaddleOCR/ppstructure/docs/table/table.jpg --type=structure --layout=false
```
<a name="215"></a>
#### 2.1.5 DocVQA
Please refer to: [Documentation Visual Q&A](../vqa/README.md) .
......@@ -62,14 +70,14 @@ Please refer to: [Documentation Visual Q&A](../vqa/README.md) .
### 2.2 Use by code
<a name="221"></a>
#### 2.2.1 layout analysis + table recognition
#### 2.2.1 image orientation + layout analysis + table recognition
```python
import os
import cv2
from paddleocr import PPStructure,draw_structure_result,save_structure_res
table_engine = PPStructure(show_log=True)
table_engine = PPStructure(show_log=True, image_orientation=True)
save_folder = './output'
img_path = 'PaddleOCR/ppstructure/docs/table/1.png'
......@@ -91,7 +99,36 @@ im_show.save('result.jpg')
```
<a name="222"></a>
#### 2.2.2 layout analysis
#### 2.2.2 layout analysis + table recognition
```python
import os
import cv2
from paddleocr import PPStructure,draw_structure_result,save_structure_res
table_engine = PPStructure(show_log=True)
save_folder = './output'
img_path = 'PaddleOCR/ppstructure/docs/table/1.png'
img = cv2.imread(img_path)
result = table_engine(img)
save_structure_res(result, save_folder,os.path.basename(img_path).split('.')[0])
for line in result:
line.pop('img')
print(line)
from PIL import Image
font_path = 'PaddleOCR/doc/fonts/simfang.ttf' # PaddleOCR下提供字体包
image = Image.open(img_path).convert('RGB')
im_show = draw_structure_result(image, result,font_path=font_path)
im_show = Image.fromarray(im_show)
im_show.save('result.jpg')
```
<a name="223"></a>
#### 2.2.3 layout analysis
```python
import os
......@@ -111,8 +148,8 @@ for line in result:
print(line)
```
<a name="223"></a>
#### 2.2.3 table recognition
<a name="224"></a>
#### 2.2.4 table recognition
```python
import os
......@@ -132,8 +169,8 @@ for line in result:
print(line)
```
<a name="224"></a>
#### 2.2.4 DocVQA
<a name="225"></a>
#### 2.2.5 DocVQA
Please refer to: [Documentation Visual Q&A](../vqa/README.md) .
......@@ -155,8 +192,8 @@ The return of PP-Structure is a list of dicts, the example is as follows:
```
Each field in dict is described as follows:
| field | description |
| --------------- |--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| field | description |
| --- |---|
|type| Type of image area. |
|bbox| The coordinates of the image area in the original image, respectively [upper left corner x, upper left corner y, lower right corner x, lower right corner y]. |
|res| OCR or table recognition result of the image area. <br> table: a dict with field descriptions as follows: <br>&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp; `html`: html str of table.<br>&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp; In the code usage mode, set return_ocr_result_in_table=True whrn call can get the detection and recognition results of each text in the table area, corresponding to the following fields: <br>&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp; `boxes`: text detection boxes.<br>&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp; `rec_res`: text recognition results.<br> OCR: A tuple containing the detection boxes and recognition results of each single text. |
......@@ -178,19 +215,26 @@ Please refer to: [Documentation Visual Q&A](../vqa/README.md) .
<a name="24"></a>
### 2.4 Parameter Description
| field | description | default |
|----------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------|
| output | The save path of result | ./output/table |
| table_max_len | When the table structure model predicts, the long side of the image | 488 |
| table_model_dir | the path of table structure model | None |
| table_char_dict_path | the dict path of table structure model | ../ppocr/utils/dict/table_structure_dict.txt |
| layout_path_model | The model path of the layout analysis model, which can be an online address or a local path. When it is a local path, layout_label_map needs to be set. In command line mode, use --layout_label_map='{0: "Text", 1: "Title", 2: "List", 3:"Table", 4:"Figure"}' | lp://PubLayNet/ppyolov2_r50vd_dcn_365e_publaynet/config |
| layout_label_map | Layout analysis model model label mapping dictionary path | None |
| model_name_or_path | the model path of VQA SER model | None |
| max_seq_length | the max token length of VQA SER model | 512 |
| label_map_path | the label path of VQA SER model | ./vqa/labels/labels_ser.txt |
| layout | Whether to perform layout analysis in forward | True |
| table | Whether to perform table recognition in forward | True |
| ocr | Whether to perform ocr for non-table areas in layout analysis. When layout is False, it will be automatically set to False | True |
| structure_version | table structure Model version number, the current model support list is as follows: PP-STRUCTURE support english table structure model | PP-STRUCTURE |
| field | description | default |
|---|---|---|
| output | result save path | ./output/table |
| table_max_len | long side of the image resize in table structure model | 488 |
| table_model_dir | Table structure model inference model path| None |
| table_char_dict_path | The dictionary path of table structure model | ../ppocr/utils/dict/table_structure_dict.txt |
| merge_no_span_structure | In the table recognition model, whether to merge '\<td>' and '\</td>' | False |
| layout_model_dir | Layout analysis model inference model path| None |
| layout_dict_path | The dictionary path of layout analysis model| ../ppocr/utils/dict/layout_publaynet_dict.txt |
| layout_score_threshold | The box threshold path of layout analysis model| 0.5|
| layout_nms_threshold | The nms threshold path of layout analysis model| 0.5|
| vqa_algorithm | vqa model algorithm| LayoutXLM|
| ser_model_dir | Ser model inference model path| None|
| ser_dict_path | The dictionary path of Ser model| ../train_data/XFUND/class_list_xfun.txt|
| mode | structure or vqa | structure |
| image_orientation | Whether to perform image orientation classification in forward | False |
| layout | Whether to perform layout analysis in forward | True |
| table | Whether to perform table recognition in forward | True |
| ocr | Whether to perform ocr for non-table areas in layout analysis. When layout is False, it will be automatically set to False| True |
| recovery | Whether to perform layout recovery in forward| False |
| structure_version | Structure version, optional PP-structure and PP-structurev2 | PP-structure |
Most of the parameters are consistent with the PaddleOCR whl package, see [whl package documentation](../../doc/doc_en/whl.md)
......@@ -27,7 +27,6 @@ import numpy as np
import time
import logging
from copy import deepcopy
from attrdict import AttrDict
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppocr.utils.logging import get_logger
......@@ -44,6 +43,13 @@ class StructureSystem(object):
def __init__(self, args):
self.mode = args.mode
self.recovery = args.recovery
self.image_orientation_predictor = None
if args.image_orientation:
import paddleclas
self.image_orientation_predictor = paddleclas.PaddleClas(
model_name="text_image_orientation")
if self.mode == 'structure':
if not args.show_log:
logger.setLevel(logging.INFO)
......@@ -74,6 +80,7 @@ class StructureSystem(object):
def __call__(self, img, return_ocr_result_in_table=False):
time_dict = {
'image_orientation': 0,
'layout': 0,
'table': 0,
'table_match': 0,
......@@ -83,6 +90,20 @@ class StructureSystem(object):
'all': 0
}
start = time.time()
if self.image_orientation_predictor is not None:
tic = time.time()
cls_result = self.image_orientation_predictor.predict(
input_data=img)
cls_res = next(cls_result)
angle = cls_res[0]['label_names'][0]
cv_rotate_code = {
'90': cv2.ROTATE_90_COUNTERCLOCKWISE,
'180': cv2.ROTATE_180,
'270': cv2.ROTATE_90_CLOCKWISE
}
img = cv2.rotate(img, cv_rotate_code[angle])
toc = time.time()
time_dict['image_orientation'] = toc - tic
if self.mode == 'structure':
ori_im = img.copy()
if self.layout_predictor is not None:
......@@ -121,7 +142,10 @@ class StructureSystem(object):
roi_img)
time_dict['det'] += ocr_time_dict['det']
time_dict['rec'] += ocr_time_dict['rec']
# remove style char
# remove style char,
# when using the recognition model trained on the PubtabNet dataset,
# it will recognize the text format in the table, such as <b>
style_token = [
'<strike>', '<strike>', '<sup>', '</sub>', '<b>',
'</b>', '<sub>', '</sup>', '<overline>',
......@@ -198,7 +222,6 @@ def main(args):
if img is None:
logger.error("error in loading image:{}".format(image_file))
continue
starttime = time.time()
res, time_dict = structure_sys(img)
if structure_sys.mode == 'structure':
......@@ -213,8 +236,7 @@ def main(args):
logger.info('result save to {}'.format(img_save_path))
if args.recovery:
convert_info_docx(img, res, save_folder, img_name)
elapse = time.time() - starttime
logger.info("Predict time : {:.3f}s".format(elapse))
logger.info("Predict time : {:.3f}s".format(time_dict['all']))
if __name__ == "__main__":
......
......@@ -62,6 +62,11 @@ def init_args():
type=str,
default='structure',
help='structure and vqa is supported')
parser.add_argument(
"--image_orientation",
type=bool,
default=False,
help='Whether to enable image orientation recognition')
parser.add_argument(
"--layout",
type=str2bool,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册