提交 d4a4c07c 编写于 作者: 文幕地方's avatar 文幕地方

add ser to ppstructure system

上级 c647a6da
...@@ -81,13 +81,14 @@ mkdir inference && cd inference ...@@ -81,13 +81,14 @@ mkdir inference && cd inference
# 下载SER XFUND 模型并解压 # 下载SER XFUND 模型并解压
wget https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_infer.tar && tar -xf ser_vi_layoutxlm_xfund_infer.tar wget https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_infer.tar && tar -xf ser_vi_layoutxlm_xfund_infer.tar
cd .. cd ..
python3 kie/predict_kie_token_ser.py \ python3 predict_system.py \
--kie_algorithm=LayoutXLM \ --kie_algorithm=LayoutXLM \
--ser_model_dir=../inference/ser_vi_layoutxlm_xfund_infer \ --ser_model_dir=./inference/ser_vi_layoutxlm_xfund_infer \
--image_dir=./docs/kie/input/zh_val_42.jpg \ --image_dir=./docs/kie/input/zh_val_42.jpg \
--ser_dict_path=../ppocr/utils/dict/kie_dict/xfund_class_list.txt \ --ser_dict_path=../ppocr/utils/dict/kie_dict/xfund_class_list.txt \
--vis_font_path=../doc/fonts/simfang.ttf \ --vis_font_path=../doc/fonts/simfang.ttf \
--ocr_order_method="tb-yx" --ocr_order_method="tb-yx" \
--mode=kie
``` ```
运行完成后,每张图片会在`output`字段指定的目录下的`kie`目录下存放可视化之后的图片,图片名和输入图片名一致。 运行完成后,每张图片会在`output`字段指定的目录下的`kie`目录下存放可视化之后的图片,图片名和输入图片名一致。
......
...@@ -82,13 +82,14 @@ mkdir inference && cd inference ...@@ -82,13 +82,14 @@ mkdir inference && cd inference
# download model # download model
wget https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_infer.tar && tar -xf ser_vi_layoutxlm_xfund_infer.tar wget https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_infer.tar && tar -xf ser_vi_layoutxlm_xfund_infer.tar
cd .. cd ..
python3 kie/predict_kie_token_ser.py \ python3 predict_system.py \
--kie_algorithm=LayoutXLM \ --kie_algorithm=LayoutXLM \
--ser_model_dir=../inference/ser_vi_layoutxlm_xfund_infer \ --ser_model_dir=./inference/ser_vi_layoutxlm_xfund_infer \
--image_dir=./docs/kie/input/zh_val_42.jpg \ --image_dir=./docs/kie/input/zh_val_42.jpg \
--ser_dict_path=../ppocr/utils/dict/kie_dict/xfund_class_list.txt \ --ser_dict_path=../ppocr/utils/dict/kie_dict/xfund_class_list.txt \
--vis_font_path=../doc/fonts/simfang.ttf \ --vis_font_path=../doc/fonts/simfang.ttf \
--ocr_order_method="tb-yx" --ocr_order_method="tb-yx" \
--mode=kie
``` ```
After the operation is completed, each image will store the visualized image in the `kie` directory under the directory specified by the `output` field, and the image name is the same as the input image name. After the operation is completed, each image will store the visualized image in the `kie` directory under the directory specified by the `output` field, and the image name is the same as the input image name.
......
...@@ -29,7 +29,7 @@ import tools.infer.utility as utility ...@@ -29,7 +29,7 @@ import tools.infer.utility as utility
from tools.infer_kie_token_ser_re import make_input from tools.infer_kie_token_ser_re import make_input
from ppocr.postprocess import build_post_process from ppocr.postprocess import build_post_process
from ppocr.utils.logging import get_logger from ppocr.utils.logging import get_logger
from ppocr.utils.visual import draw_re_results from ppocr.utils.visual import draw_ser_results, draw_re_results
from ppocr.utils.utility import get_image_file_list, check_and_read from ppocr.utils.utility import get_image_file_list, check_and_read
from ppstructure.utility import parse_args from ppstructure.utility import parse_args
from ppstructure.kie.predict_kie_token_ser import SerPredictor from ppstructure.kie.predict_kie_token_ser import SerPredictor
...@@ -41,15 +41,20 @@ class SerRePredictor(object): ...@@ -41,15 +41,20 @@ class SerRePredictor(object):
def __init__(self, args): def __init__(self, args):
self.use_visual_backbone = args.use_visual_backbone self.use_visual_backbone = args.use_visual_backbone
self.ser_engine = SerPredictor(args) self.ser_engine = SerPredictor(args)
if args.re_model_dir is not None:
postprocess_params = {'name': 'VQAReTokenLayoutLMPostProcess'} postprocess_params = {'name': 'VQAReTokenLayoutLMPostProcess'}
self.postprocess_op = build_post_process(postprocess_params) self.postprocess_op = build_post_process(postprocess_params)
self.predictor, self.input_tensor, self.output_tensors, self.config = \ self.predictor, self.input_tensor, self.output_tensors, self.config = \
utility.create_predictor(args, 're', logger) utility.create_predictor(args, 're', logger)
else:
self.predictor = None
def __call__(self, img): def __call__(self, img):
starttime = time.time() starttime = time.time()
ser_results, ser_inputs, _ = self.ser_engine(img) ser_results, ser_inputs, ser_elapse = self.ser_engine(img)
if self.predictor is None:
return ser_results, ser_elapse
re_input, entity_idx_dict_batch = make_input(ser_inputs, ser_results) re_input, entity_idx_dict_batch = make_input(ser_inputs, ser_results)
if self.use_visual_backbone == False: if self.use_visual_backbone == False:
re_input.pop(4) re_input.pop(4)
...@@ -77,7 +82,7 @@ class SerRePredictor(object): ...@@ -77,7 +82,7 @@ class SerRePredictor(object):
def main(args): def main(args):
image_file_list = get_image_file_list(args.image_dir) image_file_list = get_image_file_list(args.image_dir)
ser_predictor = SerRePredictor(args) ser_re_predictor = SerRePredictor(args)
count = 0 count = 0
total_time = 0 total_time = 0
...@@ -93,7 +98,7 @@ def main(args): ...@@ -93,7 +98,7 @@ def main(args):
if img is None: if img is None:
logger.info("error in loading image:{}".format(image_file)) logger.info("error in loading image:{}".format(image_file))
continue continue
re_res, elapse = ser_predictor(img) re_res, elapse = ser_re_predictor(img)
re_res = re_res[0] re_res = re_res[0]
res_str = '{}\t{}\n'.format( res_str = '{}\t{}\n'.format(
...@@ -103,14 +108,20 @@ def main(args): ...@@ -103,14 +108,20 @@ def main(args):
"ocr_info": re_res, "ocr_info": re_res,
}, ensure_ascii=False)) }, ensure_ascii=False))
f_w.write(res_str) f_w.write(res_str)
if ser_re_predictor.predictor is not None:
img_res = draw_re_results( img_res = draw_re_results(
image_file, re_res, font_path=args.vis_font_path) image_file, re_res, font_path=args.vis_font_path)
img_save_path = os.path.join( img_save_path = os.path.join(
args.output, args.output,
os.path.splitext(os.path.basename(image_file))[0] + os.path.splitext(os.path.basename(image_file))[0] +
"_ser_re.jpg") "_ser_re.jpg")
else:
img_res = draw_ser_results(
image_file, re_res, font_path=args.vis_font_path)
img_save_path = os.path.join(
args.output,
os.path.splitext(os.path.basename(image_file))[0] +
"_ser.jpg")
cv2.imwrite(img_save_path, img_res) cv2.imwrite(img_save_path, img_res)
logger.info("save vis result to {}".format(img_save_path)) logger.info("save vis result to {}".format(img_save_path))
......
...@@ -30,7 +30,7 @@ from copy import deepcopy ...@@ -30,7 +30,7 @@ from copy import deepcopy
from ppocr.utils.utility import get_image_file_list, check_and_read from ppocr.utils.utility import get_image_file_list, check_and_read
from ppocr.utils.logging import get_logger from ppocr.utils.logging import get_logger
from ppocr.utils.visual import draw_re_results from ppocr.utils.visual import draw_ser_results, draw_re_results
from tools.infer.predict_system import TextSystem from tools.infer.predict_system import TextSystem
from ppstructure.layout.predict_layout import LayoutPredictor from ppstructure.layout.predict_layout import LayoutPredictor
from ppstructure.table.predict_table import TableSystem, to_excel from ppstructure.table.predict_table import TableSystem, to_excel
...@@ -180,6 +180,7 @@ class StructureSystem(object): ...@@ -180,6 +180,7 @@ class StructureSystem(object):
elif self.mode == 'kie': elif self.mode == 'kie':
re_res, elapse = self.kie_predictor(img) re_res, elapse = self.kie_predictor(img)
time_dict['kie'] = elapse time_dict['kie'] = elapse
time_dict['all'] = elapse
return re_res[0], time_dict return re_res[0], time_dict
return None, None return None, None
...@@ -246,8 +247,12 @@ def main(args): ...@@ -246,8 +247,12 @@ def main(args):
draw_img = draw_structure_result(img, res, args.vis_font_path) draw_img = draw_structure_result(img, res, args.vis_font_path)
save_structure_res(res, save_folder, img_name, index) save_structure_res(res, save_folder, img_name, index)
elif structure_sys.mode == 'kie': elif structure_sys.mode == 'kie':
if structure_sys.kie_predictor.predictor is not None:
draw_img = draw_re_results( draw_img = draw_re_results(
img, res, font_path=args.vis_font_path) img, res, font_path=args.vis_font_path)
else:
draw_img = draw_ser_results(
img, res, font_path=args.vis_font_path)
with open( with open(
os.path.join(save_folder, img_name, os.path.join(save_folder, img_name,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册