diff --git a/paddleocr.py b/paddleocr.py
index 95a19147fee6dff30af2264d26aceac85b114289..f865bd08cbe231c4c39e81f131da012771280367 100644
--- a/paddleocr.py
+++ b/paddleocr.py
@@ -567,6 +567,7 @@ class PPStructure(StructureSystem):
assert params.structure_version in SUPPORT_STRUCTURE_MODEL_VERSION, "structure_version must in {}, but get {}".format(
SUPPORT_STRUCTURE_MODEL_VERSION, params.structure_version)
params.use_gpu = check_gpu(params.use_gpu)
+ params.mode = 'structure'
if not params.show_log:
logger.setLevel(logging.INFO)
diff --git a/ppstructure/docs/inference.md b/ppstructure/docs/inference.md
index 516db82784ce98abba6db14c795fe7323be508e0..7aa2fd0d99bebc2e2dddfc2bac0023f035a69a39 100644
--- a/ppstructure/docs/inference.md
+++ b/ppstructure/docs/inference.md
@@ -1,10 +1,12 @@
# 基于Python预测引擎推理
-- [1. 版面信息抽取](#1)
- - [1.1 版面分析+表格识别](#1.1)
- - [1.2 版面分析](#1.2)
- - [1.3 表格识别](#1.3)
-- [2. 关键信息抽取](#2)
+- [1. 版面信息抽取](#1-版面信息抽取)
+ - [1.1 版面分析+表格识别](#11-版面分析表格识别)
+ - [1.2 版面分析](#12-版面分析)
+ - [1.3 表格识别](#13-表格识别)
+- [2. 关键信息抽取](#2-关键信息抽取)
+ - [2.1 SER](#21-ser)
+ - [2.2 RE+SER](#22-reser)
## 1. 版面信息抽取
@@ -70,6 +72,8 @@ python3 predict_system.py --det_model_dir=inference/ch_PP-OCRv3_det_infer \
## 2. 关键信息抽取
+### 2.1 SER
+
```bash
cd ppstructure
@@ -77,13 +81,38 @@ mkdir inference && cd inference
# 下载SER XFUND 模型并解压
wget https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_infer.tar && tar -xf ser_vi_layoutxlm_xfund_infer.tar
cd ..
-python3 kie/predict_kie_token_ser.py \
+python3 predict_system.py \
--kie_algorithm=LayoutXLM \
- --ser_model_dir=../inference/ser_vi_layoutxlm_xfund_infer \
+ --ser_model_dir=./inference/ser_vi_layoutxlm_xfund_infer \
--image_dir=./docs/kie/input/zh_val_42.jpg \
--ser_dict_path=../ppocr/utils/dict/kie_dict/xfund_class_list.txt \
--vis_font_path=../doc/fonts/simfang.ttf \
- --ocr_order_method="tb-yx"
+ --ocr_order_method="tb-yx" \
+ --mode=kie
```
运行完成后,每张图片会在`output`字段指定的目录下的`kie`目录下存放可视化之后的图片,图片名和输入图片名一致。
+
+### 2.2 RE+SER
+
+```bash
+cd ppstructure
+
+mkdir inference && cd inference
+# 下载RE SER XFUND 模型并解压
+wget https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_infer.tar && tar -xf ser_vi_layoutxlm_xfund_infer.tar
+wget https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_infer.tar && tar -xf re_vi_layoutxlm_xfund_infer.tar
+cd ..
+
+python3 predict_system.py \
+ --kie_algorithm=LayoutXLM \
+ --re_model_dir=./inference/re_vi_layoutxlm_xfund_infer \
+ --ser_model_dir=./inference/ser_vi_layoutxlm_xfund_infer \
+ --image_dir=./docs/kie/input/zh_val_42.jpg \
+ --ser_dict_path=../ppocr/utils/dict/kie_dict/xfund_class_list.txt \
+ --vis_font_path=../doc/fonts/simfang.ttf \
+ --ocr_order_method="tb-yx" \
+ --mode=kie
+```
+
+运行完成后,每张图片会在`output`字段指定的目录下的`kie`目录下有一个同名目录,目录中存放可视化图片和预测结果。
diff --git a/ppstructure/docs/inference_en.md b/ppstructure/docs/inference_en.md
index 71019ec70f80e44bc16d2b0d07b0bb93b475b7e7..1bb683a684b58f6b3aa12a7be0be824031de361b 100644
--- a/ppstructure/docs/inference_en.md
+++ b/ppstructure/docs/inference_en.md
@@ -1,10 +1,12 @@
# Python Inference
-- [1. Layout Structured Analysis](#1)
- - [1.1 layout analysis + table recognition](#1.1)
- - [1.2 layout analysis](#1.2)
- - [1.3 table recognition](#1.3)
-- [2. Key Information Extraction](#2)
+- [1. Layout Structured Analysis](#1-layout-structured-analysis)
+ - [1.1 layout analysis + table recognition](#11-layout-analysis--table-recognition)
+ - [1.2 layout analysis](#12-layout-analysis)
+ - [1.3 table recognition](#13-table-recognition)
+- [2. Key Information Extraction](#2-key-information-extraction)
+ - [2.1 SER](#21-ser)
+ - [2.2 RE+SER](#22-reser)
## 1. Layout Structured Analysis
@@ -72,6 +74,7 @@ After the operation is completed, each image will have a directory with the same
## 2. Key Information Extraction
+### 2.1 SER
```bash
cd ppstructure
@@ -79,13 +82,39 @@ mkdir inference && cd inference
# download model
wget https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_infer.tar && tar -xf ser_vi_layoutxlm_xfund_infer.tar
cd ..
-python3 kie/predict_kie_token_ser.py \
+python3 predict_system.py \
--kie_algorithm=LayoutXLM \
- --ser_model_dir=../inference/ser_vi_layoutxlm_xfund_infer \
+ --ser_model_dir=./inference/ser_vi_layoutxlm_xfund_infer \
--image_dir=./docs/kie/input/zh_val_42.jpg \
--ser_dict_path=../ppocr/utils/dict/kie_dict/xfund_class_list.txt \
--vis_font_path=../doc/fonts/simfang.ttf \
- --ocr_order_method="tb-yx"
+ --ocr_order_method="tb-yx" \
+ --mode=kie
```
After the operation is completed, each image will store the visualized image in the `kie` directory under the directory specified by the `output` field, and the image name is the same as the input image name.
+
+
+### 2.2 RE+SER
+
+```bash
+cd ppstructure
+
+mkdir inference && cd inference
+# download model
+wget https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_infer.tar && tar -xf ser_vi_layoutxlm_xfund_infer.tar
+wget https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_infer.tar && tar -xf re_vi_layoutxlm_xfund_infer.tar
+cd ..
+
+python3 predict_system.py \
+ --kie_algorithm=LayoutXLM \
+ --re_model_dir=./inference/re_vi_layoutxlm_xfund_infer \
+ --ser_model_dir=./inference/ser_vi_layoutxlm_xfund_infer \
+ --image_dir=./docs/kie/input/zh_val_42.jpg \
+ --ser_dict_path=../ppocr/utils/dict/kie_dict/xfund_class_list.txt \
+ --vis_font_path=../doc/fonts/simfang.ttf \
+ --ocr_order_method="tb-yx" \
+ --mode=kie
+```
+
+After the operation is completed, each image will have a directory with the same name in the `kie` directory under the directory specified by the `output` field, where the visual images and prediction results are stored.
diff --git a/ppstructure/kie/predict_kie_token_ser_re.py b/ppstructure/kie/predict_kie_token_ser_re.py
index c0bb237db6ba53be1e141b06a56421803b76cc5d..b29a8f69dbf99fa4410136277d7d92d0d41b2039 100644
--- a/ppstructure/kie/predict_kie_token_ser_re.py
+++ b/ppstructure/kie/predict_kie_token_ser_re.py
@@ -29,13 +29,11 @@ import tools.infer.utility as utility
from tools.infer_kie_token_ser_re import make_input
from ppocr.postprocess import build_post_process
from ppocr.utils.logging import get_logger
-from ppocr.utils.visual import draw_re_results
+from ppocr.utils.visual import draw_ser_results, draw_re_results
from ppocr.utils.utility import get_image_file_list, check_and_read
from ppstructure.utility import parse_args
from ppstructure.kie.predict_kie_token_ser import SerPredictor
-from paddleocr import PaddleOCR
-
logger = get_logger()
@@ -43,15 +41,20 @@ class SerRePredictor(object):
def __init__(self, args):
self.use_visual_backbone = args.use_visual_backbone
self.ser_engine = SerPredictor(args)
-
- postprocess_params = {'name': 'VQAReTokenLayoutLMPostProcess'}
- self.postprocess_op = build_post_process(postprocess_params)
- self.predictor, self.input_tensor, self.output_tensors, self.config = \
- utility.create_predictor(args, 're', logger)
+ if args.re_model_dir is not None:
+ postprocess_params = {'name': 'VQAReTokenLayoutLMPostProcess'}
+ self.postprocess_op = build_post_process(postprocess_params)
+ self.predictor, self.input_tensor, self.output_tensors, self.config = \
+ utility.create_predictor(args, 're', logger)
+ else:
+ self.predictor = None
def __call__(self, img):
starttime = time.time()
- ser_results, ser_inputs, _ = self.ser_engine(img)
+ ser_results, ser_inputs, ser_elapse = self.ser_engine(img)
+ if self.predictor is None:
+ return ser_results, ser_elapse
+
re_input, entity_idx_dict_batch = make_input(ser_inputs, ser_results)
if self.use_visual_backbone == False:
re_input.pop(4)
@@ -79,7 +82,7 @@ class SerRePredictor(object):
def main(args):
image_file_list = get_image_file_list(args.image_dir)
- ser_predictor = SerRePredictor(args)
+ ser_re_predictor = SerRePredictor(args)
count = 0
total_time = 0
@@ -95,7 +98,7 @@ def main(args):
if img is None:
logger.info("error in loading image:{}".format(image_file))
continue
- re_res, elapse = ser_predictor(img)
+ re_res, elapse = ser_re_predictor(img)
re_res = re_res[0]
res_str = '{}\t{}\n'.format(
@@ -105,14 +108,20 @@ def main(args):
"ocr_info": re_res,
}, ensure_ascii=False))
f_w.write(res_str)
-
- img_res = draw_re_results(
- image_file, re_res, font_path=args.vis_font_path)
-
- img_save_path = os.path.join(
- args.output,
- os.path.splitext(os.path.basename(image_file))[0] +
- "_ser_re.jpg")
+ if ser_re_predictor.predictor is not None:
+ img_res = draw_re_results(
+ image_file, re_res, font_path=args.vis_font_path)
+ img_save_path = os.path.join(
+ args.output,
+ os.path.splitext(os.path.basename(image_file))[0] +
+ "_ser_re.jpg")
+ else:
+ img_res = draw_ser_results(
+ image_file, re_res, font_path=args.vis_font_path)
+ img_save_path = os.path.join(
+ args.output,
+ os.path.splitext(os.path.basename(image_file))[0] +
+ "_ser.jpg")
cv2.imwrite(img_save_path, img_res)
logger.info("save vis result to {}".format(img_save_path))
diff --git a/ppstructure/predict_system.py b/ppstructure/predict_system.py
index b827314b8911859faa449c3322ceceaf10769cf6..417002d1ef58471268071f96868617a4c9c52056 100644
--- a/ppstructure/predict_system.py
+++ b/ppstructure/predict_system.py
@@ -30,6 +30,7 @@ from copy import deepcopy
from ppocr.utils.utility import get_image_file_list, check_and_read
from ppocr.utils.logging import get_logger
+from ppocr.utils.visual import draw_ser_results, draw_re_results
from tools.infer.predict_system import TextSystem
from ppstructure.layout.predict_layout import LayoutPredictor
from ppstructure.table.predict_table import TableSystem, to_excel
@@ -75,7 +76,8 @@ class StructureSystem(object):
self.table_system = TableSystem(args)
elif self.mode == 'kie':
- raise NotImplementedError
+ from ppstructure.kie.predict_kie_token_ser_re import SerRePredictor
+ self.kie_predictor = SerRePredictor(args)
def __call__(self, img, return_ocr_result_in_table=False, img_idx=0):
time_dict = {
@@ -176,7 +178,10 @@ class StructureSystem(object):
time_dict['all'] = end - start
return res_list, time_dict
elif self.mode == 'kie':
- raise NotImplementedError
+ re_res, elapse = self.kie_predictor(img)
+ time_dict['kie'] = elapse
+ time_dict['all'] = elapse
+ return re_res[0], time_dict
return None, None
@@ -235,15 +240,32 @@ def main(args):
all_res = []
for index, img in enumerate(imgs):
res, time_dict = structure_sys(img, img_idx=index)
+ img_save_path = os.path.join(save_folder, img_name,
+ 'show_{}.jpg'.format(index))
+ os.makedirs(os.path.join(save_folder, img_name), exist_ok=True)
if structure_sys.mode == 'structure' and res != []:
- save_structure_res(res, save_folder, img_name, index)
draw_img = draw_structure_result(img, res, args.vis_font_path)
- img_save_path = os.path.join(save_folder, img_name,
- 'show_{}.jpg'.format(index))
+ save_structure_res(res, save_folder, img_name, index)
elif structure_sys.mode == 'kie':
- raise NotImplementedError
- # draw_img = draw_ser_results(img, res, args.vis_font_path)
- # img_save_path = os.path.join(save_folder, img_name + '.jpg')
+ if structure_sys.kie_predictor.predictor is not None:
+ draw_img = draw_re_results(
+ img, res, font_path=args.vis_font_path)
+ else:
+ draw_img = draw_ser_results(
+ img, res, font_path=args.vis_font_path)
+
+ with open(
+ os.path.join(save_folder, img_name,
+ 'res_{}_kie.txt'.format(index)),
+ 'w',
+ encoding='utf8') as f:
+ res_str = '{}\t{}\n'.format(
+ image_file,
+ json.dumps(
+ {
+ "ocr_info": res
+ }, ensure_ascii=False))
+ f.write(res_str)
if res != []:
cv2.imwrite(img_save_path, draw_img)
logger.info('result save to {}'.format(img_save_path))
diff --git a/ppstructure/utility.py b/ppstructure/utility.py
index 59b58edb4b0c9c5992981073b12e419fe1cc84d6..7f8a06d2ec1cd18f19975542667cc0f2cf8ad825 100644
--- a/ppstructure/utility.py
+++ b/ppstructure/utility.py
@@ -11,9 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+import random
import ast
-from PIL import Image
+from PIL import Image, ImageDraw, ImageFont
import numpy as np
from tools.infer.utility import draw_ocr_box_txt, str2bool, init_args as infer_args
@@ -64,6 +64,7 @@ def init_args():
parser.add_argument(
"--mode",
type=str,
+ choices=['structure', 'kie'],
default='structure',
help='structure and kie is supported')
parser.add_argument(