未验证 提交 87fcbf6b 编写于 作者: Z zhoujun 提交者: GitHub

Merge pull request #7859 from WenmuZhou/tipc_2

add re to ppstructure system
...@@ -567,6 +567,7 @@ class PPStructure(StructureSystem): ...@@ -567,6 +567,7 @@ class PPStructure(StructureSystem):
assert params.structure_version in SUPPORT_STRUCTURE_MODEL_VERSION, "structure_version must in {}, but get {}".format( assert params.structure_version in SUPPORT_STRUCTURE_MODEL_VERSION, "structure_version must in {}, but get {}".format(
SUPPORT_STRUCTURE_MODEL_VERSION, params.structure_version) SUPPORT_STRUCTURE_MODEL_VERSION, params.structure_version)
params.use_gpu = check_gpu(params.use_gpu) params.use_gpu = check_gpu(params.use_gpu)
params.mode = 'structure'
if not params.show_log: if not params.show_log:
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
......
# 基于Python预测引擎推理 # 基于Python预测引擎推理
- [1. 版面信息抽取](#1) - [1. 版面信息抽取](#1-版面信息抽取)
- [1.1 版面分析+表格识别](#1.1) - [1.1 版面分析+表格识别](#11-版面分析表格识别)
- [1.2 版面分析](#1.2) - [1.2 版面分析](#12-版面分析)
- [1.3 表格识别](#1.3) - [1.3 表格识别](#13-表格识别)
- [2. 关键信息抽取](#2) - [2. 关键信息抽取](#2-关键信息抽取)
- [2.1 SER](#21-ser)
- [2.2 RE+SER](#22-reser)
<a name="1"></a> <a name="1"></a>
## 1. 版面信息抽取 ## 1. 版面信息抽取
...@@ -70,6 +72,8 @@ python3 predict_system.py --det_model_dir=inference/ch_PP-OCRv3_det_infer \ ...@@ -70,6 +72,8 @@ python3 predict_system.py --det_model_dir=inference/ch_PP-OCRv3_det_infer \
<a name="2"></a> <a name="2"></a>
## 2. 关键信息抽取 ## 2. 关键信息抽取
### 2.1 SER
```bash ```bash
cd ppstructure cd ppstructure
...@@ -77,13 +81,38 @@ mkdir inference && cd inference ...@@ -77,13 +81,38 @@ mkdir inference && cd inference
# 下载SER XFUND 模型并解压 # 下载SER XFUND 模型并解压
wget https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_infer.tar && tar -xf ser_vi_layoutxlm_xfund_infer.tar wget https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_infer.tar && tar -xf ser_vi_layoutxlm_xfund_infer.tar
cd .. cd ..
python3 kie/predict_kie_token_ser.py \ python3 predict_system.py \
--kie_algorithm=LayoutXLM \ --kie_algorithm=LayoutXLM \
--ser_model_dir=../inference/ser_vi_layoutxlm_xfund_infer \ --ser_model_dir=./inference/ser_vi_layoutxlm_xfund_infer \
--image_dir=./docs/kie/input/zh_val_42.jpg \ --image_dir=./docs/kie/input/zh_val_42.jpg \
--ser_dict_path=../ppocr/utils/dict/kie_dict/xfund_class_list.txt \ --ser_dict_path=../ppocr/utils/dict/kie_dict/xfund_class_list.txt \
--vis_font_path=../doc/fonts/simfang.ttf \ --vis_font_path=../doc/fonts/simfang.ttf \
--ocr_order_method="tb-yx" --ocr_order_method="tb-yx" \
--mode=kie
``` ```
运行完成后,每张图片会在`output`字段指定的目录下的`kie`目录下存放可视化之后的图片,图片名和输入图片名一致。 运行完成后,每张图片会在`output`字段指定的目录下的`kie`目录下存放可视化之后的图片,图片名和输入图片名一致。
### 2.2 RE+SER
```bash
cd ppstructure
mkdir inference && cd inference
# 下载RE SER XFUND 模型并解压
wget https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_infer.tar && tar -xf ser_vi_layoutxlm_xfund_infer.tar
wget https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_infer.tar && tar -xf re_vi_layoutxlm_xfund_infer.tar
cd ..
python3 predict_system.py \
--kie_algorithm=LayoutXLM \
--re_model_dir=./inference/re_vi_layoutxlm_xfund_infer \
--ser_model_dir=./inference/ser_vi_layoutxlm_xfund_infer \
--image_dir=./docs/kie/input/zh_val_42.jpg \
--ser_dict_path=../ppocr/utils/dict/kie_dict/xfund_class_list.txt \
--vis_font_path=../doc/fonts/simfang.ttf \
--ocr_order_method="tb-yx" \
--mode=kie
```
运行完成后,每张图片会在`output`字段指定的目录下的`kie`目录下有一个同名目录,目录中存放可视化图片和预测结果。
# Python Inference # Python Inference
- [1. Layout Structured Analysis](#1) - [1. Layout Structured Analysis](#1-layout-structured-analysis)
- [1.1 layout analysis + table recognition](#1.1) - [1.1 layout analysis + table recognition](#11-layout-analysis--table-recognition)
- [1.2 layout analysis](#1.2) - [1.2 layout analysis](#12-layout-analysis)
- [1.3 table recognition](#1.3) - [1.3 table recognition](#13-table-recognition)
- [2. Key Information Extraction](#2) - [2. Key Information Extraction](#2-key-information-extraction)
- [2.1 SER](#21-ser)
- [2.2 RE+SER](#22-reser)
<a name="1"></a> <a name="1"></a>
## 1. Layout Structured Analysis ## 1. Layout Structured Analysis
...@@ -72,6 +74,7 @@ After the operation is completed, each image will have a directory with the same ...@@ -72,6 +74,7 @@ After the operation is completed, each image will have a directory with the same
<a name="2"></a> <a name="2"></a>
## 2. Key Information Extraction ## 2. Key Information Extraction
### 2.1 SER
```bash ```bash
cd ppstructure cd ppstructure
...@@ -79,13 +82,39 @@ mkdir inference && cd inference ...@@ -79,13 +82,39 @@ mkdir inference && cd inference
# download model # download model
wget https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_infer.tar && tar -xf ser_vi_layoutxlm_xfund_infer.tar wget https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_infer.tar && tar -xf ser_vi_layoutxlm_xfund_infer.tar
cd .. cd ..
python3 kie/predict_kie_token_ser.py \ python3 predict_system.py \
--kie_algorithm=LayoutXLM \ --kie_algorithm=LayoutXLM \
--ser_model_dir=../inference/ser_vi_layoutxlm_xfund_infer \ --ser_model_dir=./inference/ser_vi_layoutxlm_xfund_infer \
--image_dir=./docs/kie/input/zh_val_42.jpg \ --image_dir=./docs/kie/input/zh_val_42.jpg \
--ser_dict_path=../ppocr/utils/dict/kie_dict/xfund_class_list.txt \ --ser_dict_path=../ppocr/utils/dict/kie_dict/xfund_class_list.txt \
--vis_font_path=../doc/fonts/simfang.ttf \ --vis_font_path=../doc/fonts/simfang.ttf \
--ocr_order_method="tb-yx" --ocr_order_method="tb-yx" \
--mode=kie
``` ```
After the run completes, the visualized image for each input image is saved in the `kie` directory under the directory specified by the `output` field, with the same file name as the input image. After the run completes, the visualized image for each input image is saved in the `kie` directory under the directory specified by the `output` field, with the same file name as the input image.
### 2.2 RE+SER
```bash
cd ppstructure
mkdir inference && cd inference
# download the SER and RE XFUND models and extract them
wget https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_infer.tar && tar -xf ser_vi_layoutxlm_xfund_infer.tar
wget https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_infer.tar && tar -xf re_vi_layoutxlm_xfund_infer.tar
cd ..
python3 predict_system.py \
--kie_algorithm=LayoutXLM \
--re_model_dir=./inference/re_vi_layoutxlm_xfund_infer \
--ser_model_dir=./inference/ser_vi_layoutxlm_xfund_infer \
--image_dir=./docs/kie/input/zh_val_42.jpg \
--ser_dict_path=../ppocr/utils/dict/kie_dict/xfund_class_list.txt \
--vis_font_path=../doc/fonts/simfang.ttf \
--ocr_order_method="tb-yx" \
--mode=kie
```
After the run completes, a directory named after each input image is created in the `kie` directory under the directory specified by the `output` field; it contains the visualization images and the prediction results.
...@@ -29,13 +29,11 @@ import tools.infer.utility as utility ...@@ -29,13 +29,11 @@ import tools.infer.utility as utility
from tools.infer_kie_token_ser_re import make_input from tools.infer_kie_token_ser_re import make_input
from ppocr.postprocess import build_post_process from ppocr.postprocess import build_post_process
from ppocr.utils.logging import get_logger from ppocr.utils.logging import get_logger
from ppocr.utils.visual import draw_re_results from ppocr.utils.visual import draw_ser_results, draw_re_results
from ppocr.utils.utility import get_image_file_list, check_and_read from ppocr.utils.utility import get_image_file_list, check_and_read
from ppstructure.utility import parse_args from ppstructure.utility import parse_args
from ppstructure.kie.predict_kie_token_ser import SerPredictor from ppstructure.kie.predict_kie_token_ser import SerPredictor
from paddleocr import PaddleOCR
logger = get_logger() logger = get_logger()
...@@ -43,15 +41,20 @@ class SerRePredictor(object): ...@@ -43,15 +41,20 @@ class SerRePredictor(object):
def __init__(self, args): def __init__(self, args):
self.use_visual_backbone = args.use_visual_backbone self.use_visual_backbone = args.use_visual_backbone
self.ser_engine = SerPredictor(args) self.ser_engine = SerPredictor(args)
if args.re_model_dir is not None:
postprocess_params = {'name': 'VQAReTokenLayoutLMPostProcess'} postprocess_params = {'name': 'VQAReTokenLayoutLMPostProcess'}
self.postprocess_op = build_post_process(postprocess_params) self.postprocess_op = build_post_process(postprocess_params)
self.predictor, self.input_tensor, self.output_tensors, self.config = \ self.predictor, self.input_tensor, self.output_tensors, self.config = \
utility.create_predictor(args, 're', logger) utility.create_predictor(args, 're', logger)
else:
self.predictor = None
def __call__(self, img): def __call__(self, img):
starttime = time.time() starttime = time.time()
ser_results, ser_inputs, _ = self.ser_engine(img) ser_results, ser_inputs, ser_elapse = self.ser_engine(img)
if self.predictor is None:
return ser_results, ser_elapse
re_input, entity_idx_dict_batch = make_input(ser_inputs, ser_results) re_input, entity_idx_dict_batch = make_input(ser_inputs, ser_results)
if self.use_visual_backbone == False: if self.use_visual_backbone == False:
re_input.pop(4) re_input.pop(4)
...@@ -79,7 +82,7 @@ class SerRePredictor(object): ...@@ -79,7 +82,7 @@ class SerRePredictor(object):
def main(args): def main(args):
image_file_list = get_image_file_list(args.image_dir) image_file_list = get_image_file_list(args.image_dir)
ser_predictor = SerRePredictor(args) ser_re_predictor = SerRePredictor(args)
count = 0 count = 0
total_time = 0 total_time = 0
...@@ -95,7 +98,7 @@ def main(args): ...@@ -95,7 +98,7 @@ def main(args):
if img is None: if img is None:
logger.info("error in loading image:{}".format(image_file)) logger.info("error in loading image:{}".format(image_file))
continue continue
re_res, elapse = ser_predictor(img) re_res, elapse = ser_re_predictor(img)
re_res = re_res[0] re_res = re_res[0]
res_str = '{}\t{}\n'.format( res_str = '{}\t{}\n'.format(
...@@ -105,14 +108,20 @@ def main(args): ...@@ -105,14 +108,20 @@ def main(args):
"ocr_info": re_res, "ocr_info": re_res,
}, ensure_ascii=False)) }, ensure_ascii=False))
f_w.write(res_str) f_w.write(res_str)
if ser_re_predictor.predictor is not None:
img_res = draw_re_results( img_res = draw_re_results(
image_file, re_res, font_path=args.vis_font_path) image_file, re_res, font_path=args.vis_font_path)
img_save_path = os.path.join(
img_save_path = os.path.join( args.output,
args.output, os.path.splitext(os.path.basename(image_file))[0] +
os.path.splitext(os.path.basename(image_file))[0] + "_ser_re.jpg")
"_ser_re.jpg") else:
img_res = draw_ser_results(
image_file, re_res, font_path=args.vis_font_path)
img_save_path = os.path.join(
args.output,
os.path.splitext(os.path.basename(image_file))[0] +
"_ser.jpg")
cv2.imwrite(img_save_path, img_res) cv2.imwrite(img_save_path, img_res)
logger.info("save vis result to {}".format(img_save_path)) logger.info("save vis result to {}".format(img_save_path))
......
...@@ -30,6 +30,7 @@ from copy import deepcopy ...@@ -30,6 +30,7 @@ from copy import deepcopy
from ppocr.utils.utility import get_image_file_list, check_and_read from ppocr.utils.utility import get_image_file_list, check_and_read
from ppocr.utils.logging import get_logger from ppocr.utils.logging import get_logger
from ppocr.utils.visual import draw_ser_results, draw_re_results
from tools.infer.predict_system import TextSystem from tools.infer.predict_system import TextSystem
from ppstructure.layout.predict_layout import LayoutPredictor from ppstructure.layout.predict_layout import LayoutPredictor
from ppstructure.table.predict_table import TableSystem, to_excel from ppstructure.table.predict_table import TableSystem, to_excel
...@@ -75,7 +76,8 @@ class StructureSystem(object): ...@@ -75,7 +76,8 @@ class StructureSystem(object):
self.table_system = TableSystem(args) self.table_system = TableSystem(args)
elif self.mode == 'kie': elif self.mode == 'kie':
raise NotImplementedError from ppstructure.kie.predict_kie_token_ser_re import SerRePredictor
self.kie_predictor = SerRePredictor(args)
def __call__(self, img, return_ocr_result_in_table=False, img_idx=0): def __call__(self, img, return_ocr_result_in_table=False, img_idx=0):
time_dict = { time_dict = {
...@@ -176,7 +178,10 @@ class StructureSystem(object): ...@@ -176,7 +178,10 @@ class StructureSystem(object):
time_dict['all'] = end - start time_dict['all'] = end - start
return res_list, time_dict return res_list, time_dict
elif self.mode == 'kie': elif self.mode == 'kie':
raise NotImplementedError re_res, elapse = self.kie_predictor(img)
time_dict['kie'] = elapse
time_dict['all'] = elapse
return re_res[0], time_dict
return None, None return None, None
...@@ -235,15 +240,32 @@ def main(args): ...@@ -235,15 +240,32 @@ def main(args):
all_res = [] all_res = []
for index, img in enumerate(imgs): for index, img in enumerate(imgs):
res, time_dict = structure_sys(img, img_idx=index) res, time_dict = structure_sys(img, img_idx=index)
img_save_path = os.path.join(save_folder, img_name,
'show_{}.jpg'.format(index))
os.makedirs(os.path.join(save_folder, img_name), exist_ok=True)
if structure_sys.mode == 'structure' and res != []: if structure_sys.mode == 'structure' and res != []:
save_structure_res(res, save_folder, img_name, index)
draw_img = draw_structure_result(img, res, args.vis_font_path) draw_img = draw_structure_result(img, res, args.vis_font_path)
img_save_path = os.path.join(save_folder, img_name, save_structure_res(res, save_folder, img_name, index)
'show_{}.jpg'.format(index))
elif structure_sys.mode == 'kie': elif structure_sys.mode == 'kie':
raise NotImplementedError if structure_sys.kie_predictor.predictor is not None:
# draw_img = draw_ser_results(img, res, args.vis_font_path) draw_img = draw_re_results(
# img_save_path = os.path.join(save_folder, img_name + '.jpg') img, res, font_path=args.vis_font_path)
else:
draw_img = draw_ser_results(
img, res, font_path=args.vis_font_path)
with open(
os.path.join(save_folder, img_name,
'res_{}_kie.txt'.format(index)),
'w',
encoding='utf8') as f:
res_str = '{}\t{}\n'.format(
image_file,
json.dumps(
{
"ocr_info": res
}, ensure_ascii=False))
f.write(res_str)
if res != []: if res != []:
cv2.imwrite(img_save_path, draw_img) cv2.imwrite(img_save_path, draw_img)
logger.info('result save to {}'.format(img_save_path)) logger.info('result save to {}'.format(img_save_path))
......
...@@ -11,9 +11,9 @@ ...@@ -11,9 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import random
import ast import ast
from PIL import Image from PIL import Image, ImageDraw, ImageFont
import numpy as np import numpy as np
from tools.infer.utility import draw_ocr_box_txt, str2bool, init_args as infer_args from tools.infer.utility import draw_ocr_box_txt, str2bool, init_args as infer_args
...@@ -64,6 +64,7 @@ def init_args(): ...@@ -64,6 +64,7 @@ def init_args():
parser.add_argument( parser.add_argument(
"--mode", "--mode",
type=str, type=str,
choices=['structure', 'kie'],
default='structure', default='structure',
help='structure and kie is supported') help='structure and kie is supported')
parser.add_argument( parser.add_argument(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册