提交 bbff7c55 编写于 作者: 文幕地方's avatar 文幕地方

The whl package supports separate table recognition and layout analysis

上级 d7b107d1
...@@ -47,7 +47,7 @@ __all__ = [ ...@@ -47,7 +47,7 @@ __all__ = [
] ]
SUPPORT_DET_MODEL = ['DB'] SUPPORT_DET_MODEL = ['DB']
VERSION = '2.4.0.4' VERSION = '2.5'
SUPPORT_REC_MODEL = ['CRNN'] SUPPORT_REC_MODEL = ['CRNN']
BASE_DIR = os.path.expanduser("~/.paddleocr/") BASE_DIR = os.path.expanduser("~/.paddleocr/")
...@@ -442,7 +442,7 @@ class PPStructure(StructureSystem): ...@@ -442,7 +442,7 @@ class PPStructure(StructureSystem):
logger.debug(params) logger.debug(params)
super().__init__(params) super().__init__(params)
def __call__(self, img): def __call__(self, img, return_ocr_result_in_table=False):
if isinstance(img, str): if isinstance(img, str):
# download net image # download net image
if img.startswith('http'): if img.startswith('http'):
...@@ -460,7 +460,7 @@ class PPStructure(StructureSystem): ...@@ -460,7 +460,7 @@ class PPStructure(StructureSystem):
if isinstance(img, np.ndarray) and len(img.shape) == 2: if isinstance(img, np.ndarray) and len(img.shape) == 2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
res = super().__call__(img) res = super().__call__(img, return_ocr_result_in_table)
return res return res
......
...@@ -73,7 +73,7 @@ class BaseRecLabelDecode(object): ...@@ -73,7 +73,7 @@ class BaseRecLabelDecode(object):
conf_list = [0] conf_list = [0]
text = ''.join(char_list) text = ''.join(char_list)
result_list.append((text, np.mean(conf_list))) result_list.append((text, np.mean(conf_list).tolist()))
return result_list return result_list
def get_ignored_tokens(self): def get_ignored_tokens(self):
...@@ -196,7 +196,7 @@ class NRTRLabelDecode(BaseRecLabelDecode): ...@@ -196,7 +196,7 @@ class NRTRLabelDecode(BaseRecLabelDecode):
else: else:
conf_list.append(1) conf_list.append(1)
text = ''.join(char_list) text = ''.join(char_list)
result_list.append((text.lower(), np.mean(conf_list))) result_list.append((text.lower(), np.mean(conf_list).tolist()))
return result_list return result_list
...@@ -241,7 +241,7 @@ class AttnLabelDecode(BaseRecLabelDecode): ...@@ -241,7 +241,7 @@ class AttnLabelDecode(BaseRecLabelDecode):
else: else:
conf_list.append(1) conf_list.append(1)
text = ''.join(char_list) text = ''.join(char_list)
result_list.append((text, np.mean(conf_list))) result_list.append((text, np.mean(conf_list).tolist()))
return result_list return result_list
def __call__(self, preds, label=None, *args, **kwargs): def __call__(self, preds, label=None, *args, **kwargs):
...@@ -333,7 +333,7 @@ class SEEDLabelDecode(BaseRecLabelDecode): ...@@ -333,7 +333,7 @@ class SEEDLabelDecode(BaseRecLabelDecode):
else: else:
conf_list.append(1) conf_list.append(1)
text = ''.join(char_list) text = ''.join(char_list)
result_list.append((text, np.mean(conf_list))) result_list.append((text, np.mean(conf_list).tolist()))
return result_list return result_list
def __call__(self, preds, label=None, *args, **kwargs): def __call__(self, preds, label=None, *args, **kwargs):
...@@ -417,7 +417,7 @@ class SRNLabelDecode(BaseRecLabelDecode): ...@@ -417,7 +417,7 @@ class SRNLabelDecode(BaseRecLabelDecode):
conf_list.append(1) conf_list.append(1)
text = ''.join(char_list) text = ''.join(char_list)
result_list.append((text, np.mean(conf_list))) result_list.append((text, np.mean(conf_list).tolist()))
return result_list return result_list
def add_special_char(self, dict_character): def add_special_char(self, dict_character):
...@@ -636,7 +636,7 @@ class SARLabelDecode(BaseRecLabelDecode): ...@@ -636,7 +636,7 @@ class SARLabelDecode(BaseRecLabelDecode):
comp = re.compile('[^A-Z^a-z^0-9^\u4e00-\u9fa5]') comp = re.compile('[^A-Z^a-z^0-9^\u4e00-\u9fa5]')
text = text.lower() text = text.lower()
text = comp.sub('', text) text = comp.sub('', text)
result_list.append((text, np.mean(conf_list))) result_list.append((text, np.mean(conf_list).tolist()))
return result_list return result_list
def __call__(self, preds, label=None, *args, **kwargs): def __call__(self, preds, label=None, *args, **kwargs):
...@@ -699,7 +699,7 @@ class PRENLabelDecode(BaseRecLabelDecode): ...@@ -699,7 +699,7 @@ class PRENLabelDecode(BaseRecLabelDecode):
text = ''.join(char_list) text = ''.join(char_list)
if len(text) > 0: if len(text) > 0:
result_list.append((text, np.mean(conf_list))) result_list.append((text, np.mean(conf_list).tolist()))
else: else:
# here confidence of empty recog result is 1 # here confidence of empty recog result is 1
result_list.append(('', 1)) result_list.append(('', 1))
......
# 基于Python预测引擎推理 # 基于Python预测引擎推理
- [版面分析+表格识别](#1) - [1. Structure](#1)
- [DocVQA](#2) - [1.1 版面分析+表格识别](#1.1)
- [1.2 版面分析](#1.2)
- [1.3 表格识别](#1.3)
- [2. DocVQA](#2)
<a name="1"></a> <a name="1"></a>
## 1. 版面分析+表格识别 ## 1. Structure
进入`ppstructure`目录
```bash ```bash
cd ppstructure cd ppstructure
````
# 下载模型 下载模型
```bash
mkdir inference && cd inference mkdir inference && cd inference
# 下载PP-OCRv2文本检测模型并解压 # 下载PP-OCRv2文本检测模型并解压
wget https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_slim_quant_infer.tar && tar xf ch_PP-OCRv2_det_slim_quant_infer.tar wget https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_slim_quant_infer.tar && tar xf ch_PP-OCRv2_det_slim_quant_infer.tar
...@@ -18,17 +23,42 @@ wget https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_slim_quant ...@@ -18,17 +23,42 @@ wget https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_slim_quant
# 下载超轻量级英文表格预测模型并解压 # 下载超轻量级英文表格预测模型并解压
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar
cd .. cd ..
```
<a name="1.1"></a>
### 1.1 版面分析+表格识别
```bash
python3 predict_system.py --det_model_dir=inference/ch_PP-OCRv2_det_slim_quant_infer \ python3 predict_system.py --det_model_dir=inference/ch_PP-OCRv2_det_slim_quant_infer \
--rec_model_dir=inference/ch_PP-OCRv2_rec_slim_quant_infer \ --rec_model_dir=inference/ch_PP-OCRv2_rec_slim_quant_infer \
--table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer \ --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer \
--image_dir=../doc/table/1.png \ --image_dir=./docs/table/1.png \
--rec_char_dict_path=../ppocr/utils/ppocr_keys_v1.txt \ --rec_char_dict_path=../ppocr/utils/ppocr_keys_v1.txt \
--table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt \ --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt \
--output=../output/table \ --output=../output \
--vis_font_path=../doc/fonts/simfang.ttf --vis_font_path=../doc/fonts/simfang.ttf
``` ```
运行完成后,每张图片会在`output`字段指定的目录下的`talbe`目录下有一个同名目录,图片里的每个表格会存储为一个excel,图片区域会被裁剪之后保存下来,excel文件和图片名名为表格在图片里的坐标。 运行完成后,每张图片会在`output`字段指定的目录下的`structure`目录下有一个同名目录,图片里的每个表格会存储为一个excel,图片区域会被裁剪之后保存下来,excel文件和图片名为表格在图片里的坐标。详细的结果会存储在`res.txt`文件中。
<a name="1.2"></a>
### 1.2 版面分析
```bash
python3 predict_system.py --image_dir=./docs/table/1.png --table=false --ocr=false --output=../output/
```
运行完成后,每张图片会在`output`字段指定的目录下的`structure`目录下有一个同名目录,图片区域会被裁剪之后保存下来,图片名为表格在图片里的坐标。版面分析结果会存储在`res.txt`文件中。
<a name="1.3"></a>
### 1.3 表格识别
```bash
python3 predict_system.py --det_model_dir=inference/ch_PP-OCRv2_det_slim_quant_infer \
--rec_model_dir=inference/ch_PP-OCRv2_rec_slim_quant_infer \
--table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer \
--image_dir=./docs/table/table.jpg \
--rec_char_dict_path=../ppocr/utils/ppocr_keys_v1.txt \
--table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt \
--output=../output \
--vis_font_path=../doc/fonts/simfang.ttf \
--layout=false
```
运行完成后,每张图片会在`output`字段指定的目录下的`structure`目录下有一个同名目录,表格会存储为一个excel,excel文件名为`[0,0,img_h,img_w]`。
<a name="2"></a> <a name="2"></a>
## 2. DocVQA ## 2. DocVQA
......
# 基于Python预测引擎推理 # 基于Python预测引擎推理
- [版面分析+表格识别](#1) - [1. Structure](#1)
- [DocVQA](#2) - [1.1 版面分析+表格识别](#1.1)
- [1.2 版面分析](#1.2)
- [1.3 表格识别](#1.3)
- [2. DocVQA](#2)
<a name="1"></a> <a name="1"></a>
## 1. 版面分析+表格识别 ## 1. Structure
进入`ppstructure`目录
```bash ```bash
cd ppstructure cd ppstructure
````
# 下载模型 下载模型
```bash
mkdir inference && cd inference mkdir inference && cd inference
# 下载PP-OCRv2文本检测模型并解压 # 下载PP-OCRv2文本检测模型并解压
wget https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_slim_quant_infer.tar && tar xf ch_PP-OCRv2_det_slim_quant_infer.tar wget https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_slim_quant_infer.tar && tar xf ch_PP-OCRv2_det_slim_quant_infer.tar
...@@ -18,17 +23,42 @@ wget https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_slim_quant ...@@ -18,17 +23,42 @@ wget https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_slim_quant
# 下载超轻量级英文表格预测模型并解压 # 下载超轻量级英文表格预测模型并解压
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar
cd .. cd ..
```
<a name="1.1"></a>
### 1.1 版面分析+表格识别
```bash
python3 predict_system.py --det_model_dir=inference/ch_PP-OCRv2_det_slim_quant_infer \ python3 predict_system.py --det_model_dir=inference/ch_PP-OCRv2_det_slim_quant_infer \
--rec_model_dir=inference/ch_PP-OCRv2_rec_slim_quant_infer \ --rec_model_dir=inference/ch_PP-OCRv2_rec_slim_quant_infer \
--table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer \ --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer \
--image_dir=../doc/table/1.png \ --image_dir=./docs/table/1.png \
--rec_char_dict_path=../ppocr/utils/ppocr_keys_v1.txt \ --rec_char_dict_path=../ppocr/utils/ppocr_keys_v1.txt \
--table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt \ --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt \
--output=../output/table \ --output=../output \
--vis_font_path=../doc/fonts/simfang.ttf --vis_font_path=../doc/fonts/simfang.ttf
``` ```
运行完成后,每张图片会在`output`字段指定的目录下的`talbe`目录下有一个同名目录,图片里的每个表格会存储为一个excel,图片区域会被裁剪之后保存下来,excel文件和图片名名为表格在图片里的坐标。 运行完成后,每张图片会在`output`字段指定的目录下的`structure`目录下有一个同名目录,图片里的每个表格会存储为一个excel,图片区域会被裁剪之后保存下来,excel文件和图片名为表格在图片里的坐标。详细的结果会存储在`res.txt`文件中。
<a name="1.2"></a>
### 1.2 版面分析
```bash
python3 predict_system.py --image_dir=./docs/table/1.png --table=false --ocr=false --output=../output/
```
运行完成后,每张图片会在`output`字段指定的目录下的`structure`目录下有一个同名目录,图片区域会被裁剪之后保存下来,图片名为表格在图片里的坐标。版面分析结果会存储在`res.txt`文件中。
<a name="1.3"></a>
### 1.3 表格识别
```bash
python3 predict_system.py --det_model_dir=inference/ch_PP-OCRv2_det_slim_quant_infer \
--rec_model_dir=inference/ch_PP-OCRv2_rec_slim_quant_infer \
--table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer \
--image_dir=./docs/table/table.jpg \
--rec_char_dict_path=../ppocr/utils/ppocr_keys_v1.txt \
--table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt \
--output=../output \
--vis_font_path=../doc/fonts/simfang.ttf \
--layout=false
```
运行完成后,每张图片会在`output`字段指定的目录下的`structure`目录下有一个同名目录,表格会存储为一个excel,excel文件名为`[0,0,img_h,img_w]`。
<a name="2"></a> <a name="2"></a>
## 2. DocVQA ## 2. DocVQA
......
...@@ -4,10 +4,14 @@ ...@@ -4,10 +4,14 @@
- [2. 便捷使用](#2) - [2. 便捷使用](#2)
- [2.1 命令行使用](#21) - [2.1 命令行使用](#21)
- [2.1.1 版面分析+表格识别](#211) - [2.1.1 版面分析+表格识别](#211)
- [2.1.2 DocVQA](#212) - [2.1.2 版面分析](#212)
- [2.2 Python脚本使用](#22) - [2.1.3 表格识别](#213)
- [2.1.4 DocVQA](#214)
- [2.2 代码使用](#22)
- [2.2.1 版面分析+表格识别](#221) - [2.2.1 版面分析+表格识别](#221)
- [2.2.2 DocVQA](#222) - [2.2.2 版面分析](#222)
- [2.2.3 表格识别](#223)
- [2.2.4 DocVQA](#224)
- [2.3 返回结果说明](#23) - [2.3 返回结果说明](#23)
- [2.3.1 版面分析+表格识别](#231) - [2.3.1 版面分析+表格识别](#231)
- [2.3.2 DocVQA](#232) - [2.3.2 DocVQA](#232)
...@@ -36,16 +40,28 @@ pip install paddlenlp ...@@ -36,16 +40,28 @@ pip install paddlenlp
<a name="211"></a> <a name="211"></a>
#### 2.1.1 版面分析+表格识别 #### 2.1.1 版面分析+表格识别
```bash ```bash
paddleocr --image_dir=../doc/table/1.png --type=structure paddleocr --image_dir=PaddleOCR/ppstructure/docs/table/1.png --type=structure
``` ```
<a name="212"></a> <a name="212"></a>
#### 2.1.2 DocVQA #### 2.1.2 版面分析
```bash
paddleocr --image_dir=PaddleOCR/ppstructure/docs/table/1.png --type=structure --table=false --ocr=false
```
<a name="213"></a>
#### 2.1.3 表格识别
```bash
paddleocr --image_dir=PaddleOCR/ppstructure/docs/table/table.jpg --type=structure --layout=false
```
<a name="214"></a>
#### 2.1.4 DocVQA
请参考:[文档视觉问答](../vqa/README.md) 请参考:[文档视觉问答](../vqa/README.md)
<a name="22"></a> <a name="22"></a>
### 2.2 Python脚本使用 ### 2.2 代码使用
<a name="221"></a> <a name="221"></a>
#### 2.2.1 版面分析+表格识别 #### 2.2.1 版面分析+表格识别
...@@ -57,8 +73,8 @@ from paddleocr import PPStructure,draw_structure_result,save_structure_res ...@@ -57,8 +73,8 @@ from paddleocr import PPStructure,draw_structure_result,save_structure_res
table_engine = PPStructure(show_log=True) table_engine = PPStructure(show_log=True)
save_folder = './output/table' save_folder = './output'
img_path = '../doc/table/1.png' img_path = 'PaddleOCR/ppstructure/docs/table/1.png'
img = cv2.imread(img_path) img = cv2.imread(img_path)
result = table_engine(img) result = table_engine(img)
save_structure_res(result, save_folder,os.path.basename(img_path).split('.')[0]) save_structure_res(result, save_folder,os.path.basename(img_path).split('.')[0])
...@@ -69,7 +85,7 @@ for line in result: ...@@ -69,7 +85,7 @@ for line in result:
from PIL import Image from PIL import Image
font_path = '../doc/fonts/simfang.ttf' # PaddleOCR下提供字体包 font_path = 'PaddleOCR/doc/fonts/simfang.ttf' # PaddleOCR下提供字体包
image = Image.open(img_path).convert('RGB') image = Image.open(img_path).convert('RGB')
im_show = draw_structure_result(image, result,font_path=font_path) im_show = draw_structure_result(image, result,font_path=font_path)
im_show = Image.fromarray(im_show) im_show = Image.fromarray(im_show)
...@@ -77,7 +93,49 @@ im_show.save('result.jpg') ...@@ -77,7 +93,49 @@ im_show.save('result.jpg')
``` ```
<a name="222"></a> <a name="222"></a>
#### 2.2.2 DocVQA #### 2.2.2 版面分析
```python
import os
import cv2
from paddleocr import PPStructure,save_structure_res
table_engine = PPStructure(table=False, ocr=False, show_log=True)
save_folder = './output'
img_path = 'PaddleOCR/ppstructure/docs/table/1.png'
img = cv2.imread(img_path)
result = table_engine(img)
save_structure_res(result, save_folder, os.path.basename(img_path).split('.')[0])
for line in result:
line.pop('img')
print(line)
```
<a name="223"></a>
#### 2.2.3 表格识别
```python
import os
import cv2
from paddleocr import PPStructure,save_structure_res
table_engine = PPStructure(layout=False, show_log=True)
save_folder = './output'
img_path = 'PaddleOCR/ppstructure/docs/table/table.jpg'
img = cv2.imread(img_path)
result = table_engine(img)
save_structure_res(result, save_folder, os.path.basename(img_path).split('.')[0])
for line in result:
line.pop('img')
print(line)
```
<a name="224"></a>
#### 2.2.4 DocVQA
请参考:[文档视觉问答](../vqa/README.md) 请参考:[文档视觉问答](../vqa/README.md)
...@@ -99,10 +157,10 @@ PP-Structure的返回结果为一个dict组成的list,示例如下 ...@@ -99,10 +157,10 @@ PP-Structure的返回结果为一个dict组成的list,示例如下
dict 里各个字段说明如下 dict 里各个字段说明如下
| 字段 | 说明 | | 字段 | 说明 |
| --------------- | -------------| | --------------- |-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|type|图片区域的类型| |type| 图片区域的类型 |
|bbox|图片区域的在原图的坐标,分别[左上角x,左上角y,右下角x,右下角y]| |bbox| 图片区域的在原图的坐标,分别[左上角x,左上角y,右下角x,右下角y] |
|res|图片区域的OCR或表格识别结果。<br> 表格: 表格的HTML字符串; <br> OCR: 一个包含各个单行文字的检测坐标和识别结果的元组| |res| 图片区域的OCR或表格识别结果。<br> 表格: 一个dict,字段说明如下<br>&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp; `html`: 表格的HTML字符串<br>&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp; 在代码使用模式下,前向传入return_ocr_result_in_table=True可以拿到表格中每个文本的检测识别结果,对应为如下字段: <br>&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp; `boxes`: 文本检测坐标<br>&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp; `rec_res`: 文本识别结果。<br> OCR: 一个包含各个单行文字的检测坐标和识别结果的元组 |
运行完成后,每张图片会在`output`字段指定的目录下有一个同名目录,图片里的每个表格会存储为一个excel,图片区域会被裁剪之后保存下来,excel文件和图片名为表格在图片里的坐标。 运行完成后,每张图片会在`output`字段指定的目录下有一个同名目录,图片里的每个表格会存储为一个excel,图片区域会被裁剪之后保存下来,excel文件和图片名为表格在图片里的坐标。
...@@ -123,7 +181,7 @@ dict 里各个字段说明如下 ...@@ -123,7 +181,7 @@ dict 里各个字段说明如下
### 2.4 参数说明 ### 2.4 参数说明
| 字段 | 说明 | 默认值 | | 字段 | 说明 | 默认值 |
| --------------- | ---------------------------------------- | ------------------------------------------- | |----------------------|----------------------------------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------|
| output | excel和识别结果保存的地址 | ./output/table | | output | excel和识别结果保存的地址 | ./output/table |
| table_max_len | 表格结构模型预测时,图像的长边resize尺度 | 488 | | table_max_len | 表格结构模型预测时,图像的长边resize尺度 | 488 |
| table_model_dir | 表格结构模型 inference 模型地址 | None | | table_model_dir | 表格结构模型 inference 模型地址 | None |
...@@ -134,5 +192,8 @@ dict 里各个字段说明如下 ...@@ -134,5 +192,8 @@ dict 里各个字段说明如下
| max_seq_length | VQA SER模型最大支持token长度 | 512 | | max_seq_length | VQA SER模型最大支持token长度 | 512 |
| label_map_path | VQA SER 标签文件地址 | ./vqa/labels/labels_ser.txt | | label_map_path | VQA SER 标签文件地址 | ./vqa/labels/labels_ser.txt |
| mode | pipeline预测模式,structure: 版面分析+表格识别; VQA: SER文档信息抽取 | structure | | mode | pipeline预测模式,structure: 版面分析+表格识别; VQA: SER文档信息抽取 | structure |
| layout | 前向中是否执行版面分析 | True |
| table | 前向中是否执行表格识别 | True |
| ocr | 对于版面分析中的非表格区域,是否执行ocr。当layout为False时会被自动设置为False | True |
大部分参数和PaddleOCR whl包保持一致,见 [whl包文档](../../doc/doc_ch/whl.md) 大部分参数和PaddleOCR whl包保持一致,见 [whl包文档](../../doc/doc_ch/whl.md)
...@@ -4,10 +4,14 @@ ...@@ -4,10 +4,14 @@
- [2. 便捷使用](#2) - [2. 便捷使用](#2)
- [2.1 命令行使用](#21) - [2.1 命令行使用](#21)
- [2.1.1 版面分析+表格识别](#211) - [2.1.1 版面分析+表格识别](#211)
- [2.1.2 DocVQA](#212) - [2.1.2 版面分析](#212)
- [2.2 Python脚本使用](#22) - [2.1.3 表格识别](#213)
- [2.1.4 DocVQA](#214)
- [2.2 代码使用](#22)
- [2.2.1 版面分析+表格识别](#221) - [2.2.1 版面分析+表格识别](#221)
- [2.2.2 DocVQA](#222) - [2.2.2 版面分析](#222)
- [2.2.3 表格识别](#223)
- [2.2.4 DocVQA](#224)
- [2.3 返回结果说明](#23) - [2.3 返回结果说明](#23)
- [2.3.1 版面分析+表格识别](#231) - [2.3.1 版面分析+表格识别](#231)
- [2.3.2 DocVQA](#232) - [2.3.2 DocVQA](#232)
...@@ -36,16 +40,28 @@ pip install paddlenlp ...@@ -36,16 +40,28 @@ pip install paddlenlp
<a name="211"></a> <a name="211"></a>
#### 2.1.1 版面分析+表格识别 #### 2.1.1 版面分析+表格识别
```bash ```bash
paddleocr --image_dir=../doc/table/1.png --type=structure paddleocr --image_dir=PaddleOCR/ppstructure/docs/table/1.png --type=structure
``` ```
<a name="212"></a> <a name="212"></a>
#### 2.1.2 DocVQA #### 2.1.2 版面分析
```bash
paddleocr --image_dir=PaddleOCR/ppstructure/docs/table/1.png --type=structure --table=false --ocr=false
```
<a name="213"></a>
#### 2.1.3 表格识别
```bash
paddleocr --image_dir=PaddleOCR/ppstructure/docs/table/table.jpg --type=structure --layout=false
```
<a name="214"></a>
#### 2.1.4 DocVQA
请参考:[文档视觉问答](../vqa/README.md) 请参考:[文档视觉问答](../vqa/README.md)
<a name="22"></a> <a name="22"></a>
### 2.2 Python脚本使用 ### 2.2 代码使用
<a name="221"></a> <a name="221"></a>
#### 2.2.1 版面分析+表格识别 #### 2.2.1 版面分析+表格识别
...@@ -57,8 +73,8 @@ from paddleocr import PPStructure,draw_structure_result,save_structure_res ...@@ -57,8 +73,8 @@ from paddleocr import PPStructure,draw_structure_result,save_structure_res
table_engine = PPStructure(show_log=True) table_engine = PPStructure(show_log=True)
save_folder = './output/table' save_folder = './output'
img_path = '../doc/table/1.png' img_path = 'PaddleOCR/ppstructure/docs/table/1.png'
img = cv2.imread(img_path) img = cv2.imread(img_path)
result = table_engine(img) result = table_engine(img)
save_structure_res(result, save_folder,os.path.basename(img_path).split('.')[0]) save_structure_res(result, save_folder,os.path.basename(img_path).split('.')[0])
...@@ -69,7 +85,7 @@ for line in result: ...@@ -69,7 +85,7 @@ for line in result:
from PIL import Image from PIL import Image
font_path = '../doc/fonts/simfang.ttf' # PaddleOCR下提供字体包 font_path = 'PaddleOCR/doc/fonts/simfang.ttf' # PaddleOCR下提供字体包
image = Image.open(img_path).convert('RGB') image = Image.open(img_path).convert('RGB')
im_show = draw_structure_result(image, result,font_path=font_path) im_show = draw_structure_result(image, result,font_path=font_path)
im_show = Image.fromarray(im_show) im_show = Image.fromarray(im_show)
...@@ -77,7 +93,49 @@ im_show.save('result.jpg') ...@@ -77,7 +93,49 @@ im_show.save('result.jpg')
``` ```
<a name="222"></a> <a name="222"></a>
#### 2.2.2 DocVQA #### 2.2.2 版面分析
```python
import os
import cv2
from paddleocr import PPStructure,save_structure_res
table_engine = PPStructure(table=False, ocr=False, show_log=True)
save_folder = './output'
img_path = 'PaddleOCR/ppstructure/docs/table/1.png'
img = cv2.imread(img_path)
result = table_engine(img)
save_structure_res(result, save_folder, os.path.basename(img_path).split('.')[0])
for line in result:
line.pop('img')
print(line)
```
<a name="223"></a>
#### 2.2.3 表格识别
```python
import os
import cv2
from paddleocr import PPStructure,save_structure_res
table_engine = PPStructure(layout=False, show_log=True)
save_folder = './output'
img_path = 'PaddleOCR/ppstructure/docs/table/table.jpg'
img = cv2.imread(img_path)
result = table_engine(img)
save_structure_res(result, save_folder, os.path.basename(img_path).split('.')[0])
for line in result:
line.pop('img')
print(line)
```
<a name="224"></a>
#### 2.2.4 DocVQA
请参考:[文档视觉问答](../vqa/README.md) 请参考:[文档视觉问答](../vqa/README.md)
...@@ -99,10 +157,10 @@ PP-Structure的返回结果为一个dict组成的list,示例如下 ...@@ -99,10 +157,10 @@ PP-Structure的返回结果为一个dict组成的list,示例如下
dict 里各个字段说明如下 dict 里各个字段说明如下
| 字段 | 说明 | | 字段 | 说明 |
| --------------- | -------------| | --------------- |-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|type|图片区域的类型| |type| 图片区域的类型 |
|bbox|图片区域的在原图的坐标,分别[左上角x,左上角y,右下角x,右下角y]| |bbox| 图片区域的在原图的坐标,分别[左上角x,左上角y,右下角x,右下角y] |
|res|图片区域的OCR或表格识别结果。<br> 表格: 表格的HTML字符串; <br> OCR: 一个包含各个单行文字的检测坐标和识别结果的元组| |res| 图片区域的OCR或表格识别结果。<br> 表格: 一个dict,字段说明如下<br>&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp; `html`: 表格的HTML字符串<br>&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp; 在代码使用模式下,前向传入return_ocr_result_in_table=True可以拿到表格中每个文本的检测识别结果,对应为如下字段: <br>&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp; `boxes`: 文本检测坐标<br>&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp; `rec_res`: 文本识别结果。<br> OCR: 一个包含各个单行文字的检测坐标和识别结果的元组 |
运行完成后,每张图片会在`output`字段指定的目录下有一个同名目录,图片里的每个表格会存储为一个excel,图片区域会被裁剪之后保存下来,excel文件和图片名为表格在图片里的坐标。 运行完成后,每张图片会在`output`字段指定的目录下有一个同名目录,图片里的每个表格会存储为一个excel,图片区域会被裁剪之后保存下来,excel文件和图片名为表格在图片里的坐标。
...@@ -123,7 +181,7 @@ dict 里各个字段说明如下 ...@@ -123,7 +181,7 @@ dict 里各个字段说明如下
### 2.4 参数说明 ### 2.4 参数说明
| 字段 | 说明 | 默认值 | | 字段 | 说明 | 默认值 |
| --------------- | ---------------------------------------- | ------------------------------------------- | |----------------------|----------------------------------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------|
| output | excel和识别结果保存的地址 | ./output/table | | output | excel和识别结果保存的地址 | ./output/table |
| table_max_len | 表格结构模型预测时,图像的长边resize尺度 | 488 | | table_max_len | 表格结构模型预测时,图像的长边resize尺度 | 488 |
| table_model_dir | 表格结构模型 inference 模型地址 | None | | table_model_dir | 表格结构模型 inference 模型地址 | None |
...@@ -134,5 +192,8 @@ dict 里各个字段说明如下 ...@@ -134,5 +192,8 @@ dict 里各个字段说明如下
| max_seq_length | VQA SER模型最大支持token长度 | 512 | | max_seq_length | VQA SER模型最大支持token长度 | 512 |
| label_map_path | VQA SER 标签文件地址 | ./vqa/labels/labels_ser.txt | | label_map_path | VQA SER 标签文件地址 | ./vqa/labels/labels_ser.txt |
| mode | pipeline预测模式,structure: 版面分析+表格识别; VQA: SER文档信息抽取 | structure | | mode | pipeline预测模式,structure: 版面分析+表格识别; VQA: SER文档信息抽取 | structure |
| layout | 前向中是否执行版面分析 | True |
| table | 前向中是否执行表格识别 | True |
| ocr | 对于版面分析中的非表格区域,是否执行ocr。当layout为False时会被自动设置为False | True |
大部分参数和PaddleOCR whl包保持一致,见 [whl包文档](../../doc/doc_ch/whl.md) 大部分参数和PaddleOCR whl包保持一致,见 [whl包文档](../../doc/doc_ch/whl.md)
...@@ -23,9 +23,10 @@ sys.path.append(os.path.abspath(os.path.join(__dir__, '..'))) ...@@ -23,9 +23,10 @@ sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
os.environ["FLAGS_allocator_strategy"] = 'auto_growth' os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
import cv2 import cv2
import json import json
import numpy as np
import time import time
import logging import logging
from copy import deepcopy
from attrdict import AttrDict
from ppocr.utils.utility import get_image_file_list, check_and_read_gif from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppocr.utils.logging import get_logger from ppocr.utils.logging import get_logger
...@@ -40,16 +41,18 @@ class StructureSystem(object): ...@@ -40,16 +41,18 @@ class StructureSystem(object):
def __init__(self, args): def __init__(self, args):
self.mode = args.mode self.mode = args.mode
if self.mode == 'structure': if self.mode == 'structure':
import layoutparser as lp
# args.det_limit_type = 'resize_long'
args.drop_score = 0
if not args.show_log: if not args.show_log:
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
self.text_system = TextSystem(args) if args.layout == False and args.ocr == True:
self.table_system = TableSystem(args, args.ocr = False
self.text_system.text_detector, logger.warning(
self.text_system.text_recognizer) "When args.layout is false, args.ocr is automatically set to false"
)
args.drop_score = 0
# init layout and ocr model
self.text_system = None
if args.layout:
import layoutparser as lp
config_path = None config_path = None
model_path = None model_path = None
if os.path.isdir(args.layout_path_model): if os.path.isdir(args.layout_path_model):
...@@ -64,29 +67,50 @@ class StructureSystem(object): ...@@ -64,29 +67,50 @@ class StructureSystem(object):
enable_mkldnn=args.enable_mkldnn, enable_mkldnn=args.enable_mkldnn,
enforce_cpu=not args.use_gpu, enforce_cpu=not args.use_gpu,
thread_num=args.cpu_threads) thread_num=args.cpu_threads)
self.use_angle_cls = args.use_angle_cls if args.ocr:
self.drop_score = args.drop_score self.text_system = TextSystem(args)
else:
self.table_layout = None
if args.table:
if self.text_system is not None:
self.table_system = TableSystem(
args, self.text_system.text_detector,
self.text_system.text_recognizer)
else:
self.table_system = TableSystem(args)
else:
self.table_system = None
elif self.mode == 'vqa': elif self.mode == 'vqa':
raise NotImplementedError raise NotImplementedError
def __call__(self, img): def __call__(self, img, return_ocr_result_in_table=False):
if self.mode == 'structure': if self.mode == 'structure':
ori_im = img.copy() ori_im = img.copy()
if self.table_layout is not None:
layout_res = self.table_layout.detect(img[..., ::-1]) layout_res = self.table_layout.detect(img[..., ::-1])
else:
h, w = ori_im.shape[:2]
layout_res = [AttrDict(coordinates=[0, 0, w, h], type='Table')]
res_list = [] res_list = []
for region in layout_res: for region in layout_res:
res = ''
x1, y1, x2, y2 = region.coordinates x1, y1, x2, y2 = region.coordinates
x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2) x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
roi_img = ori_im[y1:y2, x1:x2, :] roi_img = ori_im[y1:y2, x1:x2, :]
if region.type == 'Table': if region.type == 'Table':
res = self.table_system(roi_img) if self.table_system is not None:
res = self.table_system(roi_img,
return_ocr_result_in_table)
else: else:
if self.text_system is not None:
filter_boxes, filter_rec_res = self.text_system(roi_img) filter_boxes, filter_rec_res = self.text_system(roi_img)
# remove style char # remove style char
style_token = [ style_token = [
'<strike>', '<strike>', '<sup>', '</sub>', '<b>', '<strike>', '<strike>', '<sup>', '</sub>', '<b>',
'</b>', '<sub>', '</sup>', '<overline>', '</overline>', '</b>', '<sub>', '</sup>', '<overline>',
'<underline>', '</underline>', '<i>', '</i>' '</overline>', '<underline>', '</underline>', '<i>',
'</i>'
] ]
res = [] res = []
for box, rec_res in zip(filter_boxes, filter_rec_res): for box, rec_res in zip(filter_boxes, filter_rec_res):
...@@ -106,31 +130,33 @@ class StructureSystem(object): ...@@ -106,31 +130,33 @@ class StructureSystem(object):
'img': roi_img, 'img': roi_img,
'res': res 'res': res
}) })
return res_list
elif self.mode == 'vqa': elif self.mode == 'vqa':
raise NotImplementedError raise NotImplementedError
return res_list return None
def save_structure_res(res, save_folder, img_name): def save_structure_res(res, save_folder, img_name):
excel_save_folder = os.path.join(save_folder, img_name) excel_save_folder = os.path.join(save_folder, img_name)
os.makedirs(excel_save_folder, exist_ok=True) os.makedirs(excel_save_folder, exist_ok=True)
res_cp = deepcopy(res)
# save res # save res
with open( with open(
os.path.join(excel_save_folder, 'res.txt'), 'w', os.path.join(excel_save_folder, 'res.txt'), 'w',
encoding='utf8') as f: encoding='utf8') as f:
for region in res: for region in res_cp:
if region['type'] == 'Table': roi_img = region.pop('img')
f.write('{}\n'.format(json.dumps(region)))
if region['type'] == 'Table' and len(region[
'res']) > 0 and 'html' in region['res']:
excel_path = os.path.join(excel_save_folder, excel_path = os.path.join(excel_save_folder,
'{}.xlsx'.format(region['bbox'])) '{}.xlsx'.format(region['bbox']))
to_excel(region['res'], excel_path) to_excel(region['res']['html'], excel_path)
elif region['type'] == 'Figure': elif region['type'] == 'Figure':
roi_img = region['img']
img_path = os.path.join(excel_save_folder, img_path = os.path.join(excel_save_folder,
'{}.jpg'.format(region['bbox'])) '{}.jpg'.format(region['bbox']))
cv2.imwrite(img_path, roi_img) cv2.imwrite(img_path, roi_img)
else:
for text_result in region['res']:
f.write('{}\n'.format(json.dumps(text_result)))
def main(args): def main(args):
......
...@@ -51,7 +51,7 @@ wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_tab ...@@ -51,7 +51,7 @@ wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_tab
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar
cd .. cd ..
# run # run
python3 table/predict_table.py --det_model_dir=inference/en_ppocr_mobile_v2.0_table_det_infer --rec_model_dir=inference/en_ppocr_mobile_v2.0_table_rec_infer --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer --image_dir=../doc/table/table.jpg --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --det_limit_side_len=736 --det_limit_type=min --output ../output/table python3 table/predict_table.py --det_model_dir=inference/en_ppocr_mobile_v2.0_table_det_infer --rec_model_dir=inference/en_ppocr_mobile_v2.0_table_rec_infer --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer --image_dir=./docs/table/table.jpg --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --det_limit_side_len=736 --det_limit_type=min --output ./output/table
``` ```
Note: The above model is trained on the PubLayNet dataset and only supports English scanning scenarios. If you need to identify other scenarios, you need to train the model yourself and replace the three fields `det_model_dir`, `rec_model_dir`, `table_model_dir`. Note: The above model is trained on the PubLayNet dataset and only supports English scanning scenarios. If you need to identify other scenarios, you need to train the model yourself and replace the three fields `det_model_dir`, `rec_model_dir`, `table_model_dir`.
......
...@@ -61,7 +61,7 @@ wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_tab ...@@ -61,7 +61,7 @@ wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_tab
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar
cd .. cd ..
# 执行预测 # 执行预测
python3 table/predict_table.py --det_model_dir=inference/en_ppocr_mobile_v2.0_table_det_infer --rec_model_dir=inference/en_ppocr_mobile_v2.0_table_rec_infer --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer --image_dir=../doc/table/table.jpg --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --det_limit_side_len=736 --det_limit_type=min --output ../output/table python3 table/predict_table.py --det_model_dir=inference/en_ppocr_mobile_v2.0_table_det_infer --rec_model_dir=inference/en_ppocr_mobile_v2.0_table_rec_infer --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer --image_dir=./docs/table/table.jpg --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --det_limit_side_len=736 --det_limit_type=min --output ./output/table
``` ```
运行完成后,每张图片的excel表格会保存到output字段指定的目录下 运行完成后,每张图片的excel表格会保存到output字段指定的目录下
......
...@@ -54,16 +54,20 @@ def expand(pix, det_box, shape): ...@@ -54,16 +54,20 @@ def expand(pix, det_box, shape):
class TableSystem(object): class TableSystem(object):
def __init__(self, args, text_detector=None, text_recognizer=None): def __init__(self, args, text_detector=None, text_recognizer=None):
self.text_detector = predict_det.TextDetector(args) if text_detector is None else text_detector self.text_detector = predict_det.TextDetector(
self.text_recognizer = predict_rec.TextRecognizer(args) if text_recognizer is None else text_recognizer args) if text_detector is None else text_detector
self.text_recognizer = predict_rec.TextRecognizer(
args) if text_recognizer is None else text_recognizer
self.table_structurer = predict_strture.TableStructurer(args) self.table_structurer = predict_strture.TableStructurer(args)
def __call__(self, img): def __call__(self, img, return_ocr_result_in_table=False):
result = dict()
ori_im = img.copy() ori_im = img.copy()
structure_res, elapse = self.table_structurer(copy.deepcopy(img)) structure_res, elapse = self.table_structurer(copy.deepcopy(img))
dt_boxes, elapse = self.text_detector(copy.deepcopy(img)) dt_boxes, elapse = self.text_detector(copy.deepcopy(img))
dt_boxes = sorted_boxes(dt_boxes) dt_boxes = sorted_boxes(dt_boxes)
if return_ocr_result_in_table:
result['boxes'] = [x.tolist() for x in dt_boxes]
r_boxes = [] r_boxes = []
for box in dt_boxes: for box in dt_boxes:
x_min = box[:, 0].min() - 1 x_min = box[:, 0].min() - 1
...@@ -88,14 +92,17 @@ class TableSystem(object): ...@@ -88,14 +92,17 @@ class TableSystem(object):
rec_res, elapse = self.text_recognizer(img_crop_list) rec_res, elapse = self.text_recognizer(img_crop_list)
logger.debug("rec_res num : {}, elapse : {}".format( logger.debug("rec_res num : {}, elapse : {}".format(
len(rec_res), elapse)) len(rec_res), elapse))
if return_ocr_result_in_table:
result['rec_res'] = rec_res
pred_html, pred = self.rebuild_table(structure_res, dt_boxes, rec_res) pred_html, pred = self.rebuild_table(structure_res, dt_boxes, rec_res)
return pred_html result['html'] = pred_html
return result
def rebuild_table(self, structure_res, dt_boxes, rec_res): def rebuild_table(self, structure_res, dt_boxes, rec_res):
pred_structures, pred_bboxes = structure_res pred_structures, pred_bboxes = structure_res
matched_index = self.match_result(dt_boxes, pred_bboxes) matched_index = self.match_result(dt_boxes, pred_bboxes)
pred_html, pred = self.get_pred_html(pred_structures, matched_index, rec_res) pred_html, pred = self.get_pred_html(pred_structures, matched_index,
rec_res)
return pred_html, pred return pred_html, pred
def match_result(self, dt_boxes, pred_bboxes): def match_result(self, dt_boxes, pred_bboxes):
...@@ -104,11 +111,13 @@ class TableSystem(object): ...@@ -104,11 +111,13 @@ class TableSystem(object):
# gt_box = [np.min(gt_box[:, 0]), np.min(gt_box[:, 1]), np.max(gt_box[:, 0]), np.max(gt_box[:, 1])] # gt_box = [np.min(gt_box[:, 0]), np.min(gt_box[:, 1]), np.max(gt_box[:, 0]), np.max(gt_box[:, 1])]
distances = [] distances = []
for j, pred_box in enumerate(pred_bboxes): for j, pred_box in enumerate(pred_bboxes):
distances.append( distances.append((distance(gt_box, pred_box),
(distance(gt_box, pred_box), 1. - compute_iou(gt_box, pred_box))) # 获取两两cell之间的L1距离和 1- IOU 1. - compute_iou(gt_box, pred_box)
)) # 获取两两cell之间的L1距离和 1- IOU
sorted_distances = distances.copy() sorted_distances = distances.copy()
# 根据距离和IOU挑选最"近"的cell # 根据距离和IOU挑选最"近"的cell
sorted_distances = sorted(sorted_distances, key=lambda item: (item[1], item[0])) sorted_distances = sorted(
sorted_distances, key=lambda item: (item[1], item[0]))
if distances.index(sorted_distances[0]) not in matched.keys(): if distances.index(sorted_distances[0]) not in matched.keys():
matched[distances.index(sorted_distances[0])] = [i] matched[distances.index(sorted_distances[0])] = [i]
else: else:
...@@ -122,7 +131,8 @@ class TableSystem(object): ...@@ -122,7 +131,8 @@ class TableSystem(object):
if '</td>' in tag: if '</td>' in tag:
if td_index in matched_index.keys(): if td_index in matched_index.keys():
b_with = False b_with = False
if '<b>' in ocr_contents[matched_index[td_index][0]] and len(matched_index[td_index]) > 1: if '<b>' in ocr_contents[matched_index[td_index][
0]] and len(matched_index[td_index]) > 1:
b_with = True b_with = True
end_html.extend('<b>') end_html.extend('<b>')
for i, td_index_index in enumerate(matched_index[td_index]): for i, td_index_index in enumerate(matched_index[td_index]):
...@@ -138,7 +148,8 @@ class TableSystem(object): ...@@ -138,7 +148,8 @@ class TableSystem(object):
content = content[:-4] content = content[:-4]
if len(content) == 0: if len(content) == 0:
continue continue
if i != len(matched_index[td_index]) - 1 and ' ' != content[-1]: if i != len(matched_index[
td_index]) - 1 and ' ' != content[-1]:
content += ' ' content += ' '
end_html.extend(content) end_html.extend(content)
if b_with: if b_with:
...@@ -187,18 +198,19 @@ def main(args): ...@@ -187,18 +198,19 @@ def main(args):
for i, image_file in enumerate(image_file_list): for i, image_file in enumerate(image_file_list):
logger.info("[{}/{}] {}".format(i, img_num, image_file)) logger.info("[{}/{}] {}".format(i, img_num, image_file))
img, flag = check_and_read_gif(image_file) img, flag = check_and_read_gif(image_file)
excel_path = os.path.join(args.output, os.path.basename(image_file).split('.')[0] + '.xlsx') excel_path = os.path.join(
args.output, os.path.basename(image_file).split('.')[0] + '.xlsx')
if not flag: if not flag:
img = cv2.imread(image_file) img = cv2.imread(image_file)
if img is None: if img is None:
logger.error("error in loading image:{}".format(image_file)) logger.error("error in loading image:{}".format(image_file))
continue continue
starttime = time.time() starttime = time.time()
pred_html = text_sys(img) pred_res = text_sys(img)
pred_html = pred_res['html']
logger.info(pred_html)
to_excel(pred_html, excel_path) to_excel(pred_html, excel_path)
logger.info('excel saved to {}'.format(excel_path)) logger.info('excel saved to {}'.format(excel_path))
logger.info(pred_html)
elapse = time.time() - starttime elapse = time.time() - starttime
logger.info("Predict time : {:.3f}s".format(elapse)) logger.info("Predict time : {:.3f}s".format(elapse))
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
import ast import ast
from PIL import Image from PIL import Image
import numpy as np import numpy as np
from tools.infer.utility import draw_ocr_box_txt, init_args as infer_args from tools.infer.utility import draw_ocr_box_txt, str2bool, init_args as infer_args
def init_args(): def init_args():
...@@ -30,6 +30,7 @@ def init_args(): ...@@ -30,6 +30,7 @@ def init_args():
"--table_char_dict_path", "--table_char_dict_path",
type=str, type=str,
default="../ppocr/utils/dict/table_structure_dict.txt") default="../ppocr/utils/dict/table_structure_dict.txt")
# params for layout
parser.add_argument( parser.add_argument(
"--layout_path_model", "--layout_path_model",
type=str, type=str,
...@@ -39,11 +40,27 @@ def init_args(): ...@@ -39,11 +40,27 @@ def init_args():
type=ast.literal_eval, type=ast.literal_eval,
default=None, default=None,
help='label map according to ppstructure/layout/README_ch.md') help='label map according to ppstructure/layout/README_ch.md')
# params for inference
parser.add_argument( parser.add_argument(
"--mode", "--mode",
type=str, type=str,
default='structure', default='structure',
help='structure and vqa is supported') help='structure and vqa is supported')
parser.add_argument(
"--layout",
type=str2bool,
default=True,
help='Whether to enable layout analysis')
parser.add_argument(
"--table",
type=str2bool,
default=True,
help='In the forward, whether the table area uses table recognition')
parser.add_argument(
"--ocr",
type=str2bool,
default=True,
help='In the forward, whether the non-table area is recognition by ocr')
return parser return parser
......
...@@ -12,3 +12,4 @@ cython ...@@ -12,3 +12,4 @@ cython
lxml lxml
premailer premailer
openpyxl openpyxl
attrdict
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册