From bc0d766425e891e001dbf904079658f55f9ab60f Mon Sep 17 00:00:00 2001 From: WenmuZhou Date: Sat, 5 Jun 2021 22:00:17 +0800 Subject: [PATCH] init commit for paddlestructure --- MANIFEST.in | 2 +- paddleocr.py | 3 + ppocr/utils/dict/table_dict.txt | 3 +- ppstructure/MANIFEST.in | 9 ++ ppstructure/layout/README.md | 0 ppstructure/layout/README_ch.md | 0 ppstructure/paddlestructure.py | 161 +++++++++++++++++++++++++ ppstructure/predict_system.py | 102 ++++++++-------- ppstructure/setup.py | 65 ++++++++++ ppstructure/table/README_ch.md | 15 +++ ppstructure/table/eval_table.py | 15 ++- ppstructure/table/matcher.py | 18 --- ppstructure/table/predict_structure.py | 14 +-- ppstructure/table/predict_table.py | 31 +++-- ppstructure/utility.py | 40 ++++++ tools/infer/predict_system.py | 8 +- tools/infer/utility.py | 9 +- 17 files changed, 385 insertions(+), 110 deletions(-) create mode 100644 ppstructure/MANIFEST.in delete mode 100644 ppstructure/layout/README.md delete mode 100644 ppstructure/layout/README_ch.md create mode 100644 ppstructure/paddlestructure.py create mode 100644 ppstructure/setup.py create mode 100644 ppstructure/utility.py diff --git a/MANIFEST.in b/MANIFEST.in index e16f157d..cd34d574 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,7 +1,7 @@ include LICENSE.txt include README.md -recursive-include ppocr/utils *.txt utility.py logging.py +recursive-include ppocr/utils *.txt utility.py logging.py network.py recursive-include ppocr/data/ *.py recursive-include ppocr/postprocess *.py recursive-include tools/infer *.py diff --git a/paddleocr.py b/paddleocr.py index 708f20b1..48c8c9c6 100644 --- a/paddleocr.py +++ b/paddleocr.py @@ -19,6 +19,7 @@ __dir__ = os.path.dirname(__file__) sys.path.append(os.path.join(__dir__, '')) import cv2 +import logging import numpy as np from pathlib import Path @@ -150,6 +151,8 @@ class PaddleOCR(predict_system.TextSystem): """ params = parse_args(mMain=False) params.__dict__.update(**kwargs) + if params.show_log: + logger.setLevel(logging.DEBUG) self.use_angle_cls = params.use_angle_cls lang = params.lang latin_lang = [ diff --git a/ppocr/utils/dict/table_dict.txt b/ppocr/utils/dict/table_dict.txt index 804f3e31..2ef028c7 100644 --- a/ppocr/utils/dict/table_dict.txt +++ b/ppocr/utils/dict/table_dict.txt @@ -33,8 +33,7 @@ D Π H ║ - + L Φ Χ diff --git a/ppstructure/MANIFEST.in b/ppstructure/MANIFEST.in new file mode 100644 index 00000000..f9bd0fe9 --- /dev/null +++ b/ppstructure/MANIFEST.in @@ -0,0 +1,9 @@ +include LICENSE.txt +include README.md + +recursive-include ppocr/utils *.txt utility.py logging.py network.py +recursive-include ppocr/data/ *.py +recursive-include ppocr/postprocess *.py +recursive-include tools/infer *.py +recursive-include table *.py +recursive-include ppstructure *.py \ No newline at end of file diff --git a/ppstructure/layout/README.md b/ppstructure/layout/README.md deleted file mode 100644 index e69de29b..00000000 diff --git a/ppstructure/layout/README_ch.md b/ppstructure/layout/README_ch.md deleted file mode 100644 index e69de29b..00000000 diff --git a/ppstructure/paddlestructure.py b/ppstructure/paddlestructure.py new file mode 100644 index 00000000..c2db42c1 --- /dev/null +++ b/ppstructure/paddlestructure.py @@ -0,0 +1,161 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import os +import sys + +__dir__ = os.path.dirname(__file__) +sys.path.append(os.path.join(__dir__, '')) + + +import cv2 +import numpy as np +from pathlib import Path + +from ppocr.utils.logging import get_logger +from predict_system import OCRSystem, save_res +from utility import init_args + +logger = get_logger() +from ppocr.utils.utility import check_and_read_gif, get_image_file_list +from ppocr.utils.network import maybe_download, download_with_progressbar + +__all__ = ['PaddleStructure'] + +VERSION = '2.1' +BASE_DIR = os.path.expanduser("~/.paddlestructure/") + +model_urls = { + 'det': { + 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar', + }, + 'rec': { + 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar', + }, + 'structure': { + 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar', + }, +} + + +def parse_args(mMain=True): + import argparse + parser = init_args() + parser.add_help = mMain + + for action in parser._actions: + if action.dest in ['rec_char_dict_path', 'structure_char_dict_path']: + action.default = None + if mMain: + return parser.parse_args() + else: + inference_args_dict = {} + for action in parser._actions: + inference_args_dict[action.dest] = action.default + return argparse.Namespace(**inference_args_dict) + + +class PaddleStructure(OCRSystem): + def __init__(self, **kwargs): + params = parse_args(mMain=False) + params.__dict__.update(**kwargs) + if params.show_log: + logger.setLevel(logging.DEBUG) + params.use_angle_cls = False + # init model dir + if params.det_model_dir is None: + params.det_model_dir = os.path.join(BASE_DIR, VERSION, 'det') + if params.rec_model_dir is None: + params.rec_model_dir = os.path.join(BASE_DIR, VERSION, 'rec') + if params.structure_model_dir is None: + params.structure_model_dir = os.path.join(BASE_DIR, VERSION, 'structure') + # download model + maybe_download(params.det_model_dir, model_urls['det']) + maybe_download(params.det_model_dir, model_urls['rec']) + maybe_download(params.det_model_dir, model_urls['structure']) + + if params.rec_char_dict_path is None: + params.rec_char_type = 'EN' + if os.path.exists(str(Path(__file__).parent / 'ppocr/utils/dict/table_dict.txt')): + params.rec_char_dict_path = str(Path(__file__).parent / 'ppocr/utils/dict/table_dict.txt') + else: + params.rec_char_dict_path = str(Path(__file__).parent.parent / 'ppocr/utils/dict/table_dict.txt') + if params.structure_char_dict_path is None: + if os.path.exists(str(Path(__file__).parent / 'ppocr/utils/dict/table_structure_dict.txt')): + params.structure_char_dict_path = str(Path(__file__).parent / 'ppocr/utils/dict/table_structure_dict.txt') + else: + params.structure_char_dict_path = str(Path(__file__).parent.parent / 'ppocr/utils/dict/table_structure_dict.txt') + + print(params) + super().__init__(params) + + def __call__(self, img): + if isinstance(img, str): + # download net image + if img.startswith('http'): + download_with_progressbar(img, 'tmp.jpg') + img = 'tmp.jpg' + image_file = img + img, flag = check_and_read_gif(image_file) + if not flag: + with open(image_file, 'rb') as f: + np_arr = np.frombuffer(f.read(), dtype=np.uint8) + img = cv2.imdecode(np_arr, cv2.IMREAD_COLOR) + if img is None: + logger.error("error in loading image:{}".format(image_file)) + return None + if isinstance(img, np.ndarray) and len(img.shape) == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + + res = super().__call__(img) + return res + + +def main(): + # for cmd + args = parse_args(mMain=True) + image_dir = args.image_dir + save_folder = args.output + if image_dir.startswith('http'): + download_with_progressbar(image_dir, 'tmp.jpg') + image_file_list = ['tmp.jpg'] + else: + image_file_list = get_image_file_list(args.image_dir) + if len(image_file_list) == 0: + logger.error('no images find in {}'.format(args.image_dir)) + return + + structure_engine = PaddleStructure(**(args.__dict__)) + for img_path in image_file_list: + img_name = os.path.basename(img_path).split('.')[0] + logger.info('{}{}{}'.format('*' * 10, img_path, '*' * 10)) + result = structure_engine(img_path) + save_res(result, args.output, os.path.basename(img_path).split('.')[0]) + for item in result: + logger.info(item['res']) + save_res(result, save_folder, img_name) + logger.info('result save to {}'.format(os.path.join(save_folder, img_name))) + + + +if __name__ == '__main__': + table_engine = PaddleStructure(det_model_dir='../inference/table/ch_ppocr_mobile_v2.0_table_det_infer', + rec_model_dir='../inference/table/ch_ppocr_mobile_v2.0_table_rec_infer', + structure_model_dir='../inference/table/ch_ppocr_mobile_v2.0_table_structure_infer', + output='/Users/zhoujun20/Desktop/工作相关/table/table_pr/PaddleOCR/output/table', + show_log=True) + img = cv2.imread('/Users/zhoujun20/Desktop/工作相关/table/table_pr/PaddleOCR/ppstructure/test_imgs/table_1.png') + result = table_engine(img) + for line in result: + print(line) diff --git a/ppstructure/predict_system.py b/ppstructure/predict_system.py index f6852ab7..e40aa8a8 100644 --- a/ppstructure/predict_system.py +++ b/ppstructure/predict_system.py @@ -18,97 +18,93 @@ import subprocess __dir__ = os.path.dirname(os.path.abspath(__file__)) sys.path.append(__dir__) -sys.path.append(os.path.abspath(os.path.join(__dir__, '../..'))) +sys.path.append(os.path.abspath(os.path.join(__dir__, '..'))) os.environ["FLAGS_allocator_strategy"] = 'auto_growth' import cv2 -import copy import numpy as np import time -import tools.infer.utility as utility -from tools.infer.predict_system import TextSystem -from ppstructure.table.predict_table import TableSystem, to_excel -from ppstructure.layout.predict_layout import LayoutDetector + +import layoutparser as lp + from ppocr.utils.utility import get_image_file_list, check_and_read_gif from ppocr.utils.logging import get_logger +from tools.infer.predict_system import TextSystem +from ppstructure.table.predict_table import TableSystem, to_excel +from ppstructure.utility import parse_args logger = get_logger() -def parse_args(): - parser = utility.init_args() - - # params for output - parser.add_argument("--table_output", type=str, default='output/table') - # params for table structure - parser.add_argument("--table_max_len", type=int, default=488) - parser.add_argument("--table_max_text_length", type=int, default=100) - parser.add_argument("--table_max_elem_length", type=int, default=800) - parser.add_argument("--table_max_cell_num", type=int, default=500) - parser.add_argument("--table_model_dir", type=str) - parser.add_argument("--table_char_type", type=str, default='en') - parser.add_argument("--table_char_dict_path", type=str, default="./ppocr/utils/dict/table_structure_dict.txt") - - # params for layout detector - parser.add_argument("--layout_model_dir", type=str) - return parser.parse_args() - - -class OCRSystem(): +class OCRSystem(object): def __init__(self, args): self.text_system = TextSystem(args) - self.table_system = TableSystem(args) - self.table_layout = LayoutDetector(args) + self.table_system = TableSystem(args, self.text_system.text_detector, self.text_system.text_recognizer) + self.table_layout = lp.PaddleDetectionLayoutModel("lp://PubLayNet/ppyolov2_r50vd_dcn_365e_publaynet/config", + threshold=0.5, enable_mkldnn=args.enable_mkldnn, + enforce_cpu=not args.use_gpu) self.use_angle_cls = args.use_angle_cls self.drop_score = args.drop_score def __call__(self, img): ori_im = img.copy() - layout_res = self.table_layout(copy.deepcopy(img)) + layout_res = self.table_layout.detect(img[..., ::-1]) + res_list = [] for region in layout_res: - x1, y1, x2, y2 = region['bbox'] + x1, y1, x2, y2 = region.coordinates + x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2) roi_img = ori_im[y1:y2, x1:x2, :] - if region['label'] == 'table': - res = self.text_system(roi_img) + if region.type == 'Table': + res = self.table_system(roi_img) + elif region.type == 'Figure': + continue else: - res = self.text_system(roi_img) - region['res'] = res - return layout_res + filter_boxes, filter_rec_res = self.text_system(roi_img) + filter_boxes = [x.reshape(-1).tolist() for x in filter_boxes] + res = (filter_boxes, filter_rec_res) + res_list.append({'type': region.type, 'bbox': [x1, y1, x2, y2], 'res': res}) + return res_list + + +def save_res(res, save_folder, img_name): + excel_save_folder = os.path.join(save_folder, img_name) + os.makedirs(excel_save_folder, exist_ok=True) + # save res + for region in res: + if region['type'] == 'Table': + excel_path = os.path.join(excel_save_folder, '{}.xlsx'.format(region['bbox'])) + to_excel(region['res'], excel_path) + elif region['type'] == 'Figure': + pass + else: + with open(os.path.join(excel_save_folder, 'res.txt'), 'a', encoding='utf8') as f: + for box, rec_res in zip(*region['res']): + f.write('{}\t{}\n'.format(np.array(box).reshape(-1).tolist(), rec_res)) def main(args): image_file_list = get_image_file_list(args.image_dir) + image_file_list = image_file_list image_file_list = image_file_list[args.process_id::args.total_process_num] - save_folder = args.table_output + save_folder = args.output os.makedirs(save_folder, exist_ok=True) - text_sys = OCRSystem(args) + structure_sys = OCRSystem(args) img_num = len(image_file_list) for i, image_file in enumerate(image_file_list): logger.info("[{}/{}] {}".format(i, img_num, image_file)) img, flag = check_and_read_gif(image_file) img_name = os.path.basename(image_file).split('.')[0] - # excel_path = os.path.join(excel_save_folder, + '.xlsx') + if not flag: img = cv2.imread(image_file) if img is None: - logger.info("error in loading image:{}".format(image_file)) + logger.error("error in loading image:{}".format(image_file)) continue starttime = time.time() - res = text_sys(img) - - excel_save_folder = os.path.join(save_folder, img_name) - os.makedirs(excel_save_folder, exist_ok=True) - # save res - for region in res: - if region['label'] == 'table': - excel_path = os.path.join(excel_save_folder, '{}.xlsx'.format(region['bbox'])) - to_excel(region['res'], excel_path) - else: - with open(os.path.join(excel_save_folder, 'res.txt'),'a',encoding='utf8') as f: - for box, rec_res in zip(*region['res']): - f.write('{}\t{}\n'.format(np.array(box).reshape(-1).tolist(), rec_res)) - logger.info(res) + res = structure_sys(img) + save_res(res, save_folder, img_name) + logger.info('result save to {}'.format(os.path.join(save_folder, img_name))) elapse = time.time() - starttime logger.info("Predict time : {:.3f}s".format(elapse)) diff --git a/ppstructure/setup.py b/ppstructure/setup.py new file mode 100644 index 00000000..493599b7 --- /dev/null +++ b/ppstructure/setup.py @@ -0,0 +1,65 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from setuptools import setup +from io import open +import shutil + +with open('../requirements.txt', encoding="utf-8-sig") as f: + requirements = f.readlines() + requirements.append('tqdm') + requirements.append('layoutparser') + + +def readme(): + with open('README_ch.md', encoding="utf-8-sig") as f: + README = f.read() + return README + +shutil.copytree('../ppocr','./ppocr') +shutil.copytree('../tools','./tools') +shutil.copytree('../ppstructure','./ppstructure') + +setup( + name='paddlestructure', + packages=['paddlestructure'], + package_dir={'paddlestructure': ''}, + include_package_data=True, + entry_points={"console_scripts": ["paddlestructure= paddlestructure.paddlestructure:main"]}, + version='2.0.6', + install_requires=requirements, + license='Apache License 2.0', + description='Awesome OCR toolkits based on PaddlePaddle (8.6M ultra-lightweight pre-trained model, support training and deployment among server, mobile, embeded and IoT devices', + long_description=readme(), + long_description_content_type='text/markdown', + url='https://github.com/PaddlePaddle/PaddleOCR', + download_url='https://github.com/PaddlePaddle/PaddleOCR.git', + keywords=[ + 'ocr textdetection textrecognition paddleocr crnn east star-net rosetta ocrlite db chineseocr chinesetextdetection chinesetextrecognition' + ], + classifiers=[ + 'Intended Audience :: Developers', 'Operating System :: OS Independent', + 'Natural Language :: Chinese (Simplified)', + 'Programming Language :: Python :: 3', + 'Programming Language :: Python :: 3.2', + 'Programming Language :: Python :: 3.3', + 'Programming Language :: Python :: 3.4', + 'Programming Language :: Python :: 3.5', + 'Programming Language :: Python :: 3.6', + 'Programming Language :: Python :: 3.7', 'Topic :: Utilities' + ], ) + +shutil.rmtree('ppocr') +shutil.rmtree('tools') +shutil.rmtree('ppstructure') \ No newline at end of file diff --git a/ppstructure/table/README_ch.md b/ppstructure/table/README_ch.md index e69de29b..effd1cf2 100644 --- a/ppstructure/table/README_ch.md +++ b/ppstructure/table/README_ch.md @@ -0,0 +1,15 @@ +# 表格结构和内容预测 + +先cd到PaddleOCR/ppstructure目录下 + +预测 +```python +python3 table/predict_table.py --det_model_dir=../inference/db --rec_model_dir=../inference/rec_mv3_large1.0/infer --table_model_dir=../inference/explite3/infer --image_dir=../table/imgs/PMC3006023_004_00.png --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=EN --det_limit_side_len=736 --det_limit_type=min --table_output ../output/table +``` +运行完成后,每张图片的excel表格会保存到table_output字段指定的目录下 + +eval + +```python +python3 table/eval_table.py --det_model_dir=../inference/db --rec_model_dir=../inference/rec_mv3_large1.0/infer --table_model_dir=../inference/explite3/infer --image_dir=../table/imgs --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=EN --det_limit_side_len=736 --det_limit_type=min --gt_path=path/to/gt.json +``` diff --git a/ppstructure/table/eval_table.py b/ppstructure/table/eval_table.py index 46df68df..0ba7acbc 100755 --- a/ppstructure/table/eval_table.py +++ b/ppstructure/table/eval_table.py @@ -15,16 +15,21 @@ import os import sys __dir__ = os.path.dirname(os.path.abspath(__file__)) sys.path.append(__dir__) -sys.path.append(os.path.abspath(os.path.join(__dir__, '..'))) +sys.path.append(os.path.abspath(os.path.join(__dir__, '../..'))) import cv2 import json from tqdm import tqdm from ppstructure.table.table_metric import TEDS from ppstructure.table.predict_table import TableSystem -from ppstructure.predict_system import parse_args +from ppstructure.utility import init_args +def parse_args(): + parser = init_args() + parser.add_argument("--gt_path", type=str) + return parser.parse_args() + def main(gt_path, img_root, args): teds = TEDS(n_jobs=16) @@ -33,6 +38,8 @@ def main(gt_path, img_root, args): pred_htmls = [] gt_htmls = [] for img_name in tqdm(jsons_gt): + if img_name != 'PMC1064865_002_00.png': + continue # 读取信息 img = cv2.imread(os.path.join(img_root,img_name)) pred_html = text_sys(img) @@ -61,6 +68,4 @@ def get_gt_html(gt_structures, contents_with_block): if __name__ == '__main__': args = parse_args() - gt_path = 'table/match_code/f_gt_bbox.json' - img_root = 'table/imgs' - main(gt_path,img_root, args) + main(args.gt_path,args.image_dir, args) diff --git a/ppstructure/table/matcher.py b/ppstructure/table/matcher.py index 711806aa..b3c70430 100755 --- a/ppstructure/table/matcher.py +++ b/ppstructure/table/matcher.py @@ -194,21 +194,3 @@ def matcher_structure(gt_bboxes, pred_bboxes_rows, pred_bboxes): matched[index].append(i) pre_bbox = gt_box return matched - - -def main(): - detect_bboxes = json.load(open('./f_detecion_bbox.json')) - gt_bboxes = json.load(open('./f_gt_bbox.json')) - all_node = 0 - matched_right = 0 - key = 'PMC4796501_003_00.png' - print(key) - gt_bbox = gt_bboxes[key] - pred_bbox = detect_bboxes[key] - matched = matcher(gt_bbox, pred_bbox) - print(matched) - - -if __name__ == "__main__": - main() - diff --git a/ppstructure/table/predict_structure.py b/ppstructure/table/predict_structure.py index fd00dfd1..6e680b35 100755 --- a/ppstructure/table/predict_structure.py +++ b/ppstructure/table/predict_structure.py @@ -40,7 +40,7 @@ class TableStructurer(object): def __init__(self, args): pre_process_list = [{ 'ResizeTableImage': { - 'max_len': args.table_max_len + 'max_len': args.structure_max_len } }, { 'NormalizeImage': { @@ -60,17 +60,17 @@ class TableStructurer(object): }] postprocess_params = { 'name': 'TableLabelDecode', - "character_type": args.table_char_type, - "character_dict_path": args.table_char_dict_path, - "max_text_length": args.table_max_text_length, - "max_elem_length": args.table_max_elem_length, - "max_cell_num": args.table_max_cell_num + "character_type": args.structure_char_type, + "character_dict_path": args.structure_char_dict_path, + "max_text_length": args.structure_max_text_length, + "max_elem_length": args.structure_max_elem_length, + "max_cell_num": args.structure_max_cell_num } self.preprocess_op = create_operators(pre_process_list) self.postprocess_op = build_post_process(postprocess_params) self.predictor, self.input_tensor, self.output_tensors = \ - utility.create_predictor(args, 'table', logger) + utility.create_predictor(args, 'structure', logger) def __call__(self, img): ori_im = img.copy() diff --git a/ppstructure/table/predict_table.py b/ppstructure/table/predict_table.py index 30e503e5..4a247e40 100644 --- a/ppstructure/table/predict_table.py +++ b/ppstructure/table/predict_table.py @@ -18,6 +18,7 @@ import subprocess __dir__ = os.path.dirname(os.path.abspath(__file__)) sys.path.append(__dir__) +sys.path.append(os.path.abspath(os.path.join(__dir__, '..'))) sys.path.append(os.path.abspath(os.path.join(__dir__, '../..'))) os.environ["FLAGS_allocator_strategy"] = 'auto_growth' @@ -25,13 +26,13 @@ import cv2 import copy import numpy as np import time -import tools.infer.utility as utility import tools.infer.predict_rec as predict_rec import tools.infer.predict_det as predict_det import ppstructure.table.predict_structure as predict_strture from ppocr.utils.utility import get_image_file_list, check_and_read_gif from ppocr.utils.logging import get_logger -from ppstructure.table.matcher import distance, compute_iou +from matcher import distance, compute_iou +from ppstructure.utility import parse_args logger = get_logger() @@ -52,12 +53,10 @@ def expand(pix, det_box, shape): class TableSystem(object): - def __init__(self, args): - self.text_detector = predict_det.TextDetector(args) - self.text_recognizer = predict_rec.TextRecognizer(args) + def __init__(self, args, text_detector=None, text_recognizer=None): + self.text_detector = predict_det.TextDetector(args) if text_detector is None else text_detector + self.text_recognizer = predict_rec.TextRecognizer(args) if text_recognizer is None else text_recognizer self.table_structurer = predict_strture.TableStructurer(args) - self.use_angle_cls = args.use_angle_cls - self.drop_score = args.drop_score def __call__(self, img): ori_im = img.copy() @@ -75,8 +74,8 @@ class TableSystem(object): r_boxes.append(box) dt_boxes = np.array(r_boxes) - # logger.info("dt_boxes num : {}, elapse : {}".format( - # len(dt_boxes), elapse)) + logger.debug("dt_boxes num : {}, elapse : {}".format( + len(dt_boxes), elapse)) if dt_boxes is None: return None, None img_crop_list = [] @@ -87,8 +86,8 @@ class TableSystem(object): text_rect = ori_im[int(y0):int(y1), int(x0):int(x1), :] img_crop_list.append(text_rect) rec_res, elapse = self.text_recognizer(img_crop_list) - # logger.info("rec_res num : {}, elapse : {}".format( - # len(rec_res), elapse)) + logger.debug("rec_res num : {}, elapse : {}".format( + len(rec_res), elapse)) pred_html, pred = self.rebuild_table(structure_res, dt_boxes, rec_res) return pred_html @@ -172,6 +171,7 @@ def sorted_boxes(dt_boxes): _boxes[i + 1] = tmp return _boxes + def to_excel(html_table, excel_path): from tablepyxl import tablepyxl tablepyxl.document_to_xl(html_table, excel_path) @@ -180,19 +180,18 @@ def to_excel(html_table, excel_path): def main(args): image_file_list = get_image_file_list(args.image_dir) image_file_list = image_file_list[args.process_id::args.total_process_num] - excel_save_folder = 'output/table' - os.makedirs(excel_save_folder, exist_ok=True) + os.makedirs(args.output, exist_ok=True) text_sys = TableSystem(args) img_num = len(image_file_list) for i, image_file in enumerate(image_file_list): logger.info("[{}/{}] {}".format(i, img_num, image_file)) img, flag = check_and_read_gif(image_file) - excel_path = os.path.join(excel_save_folder, os.path.basename(image_file).split('.')[0] + '.xlsx') + excel_path = os.path.join(args.table_output, os.path.basename(image_file).split('.')[0] + '.xlsx') if not flag: img = cv2.imread(image_file) if img is None: - logger.info("error in loading image:{}".format(image_file)) + logger.error("error in loading image:{}".format(image_file)) continue starttime = time.time() pred_html = text_sys(img) @@ -205,7 +204,7 @@ def main(args): if __name__ == "__main__": - args = utility.parse_args() + args = parse_args() if args.use_mp: p_list = [] total_process_num = args.total_process_num diff --git a/ppstructure/utility.py b/ppstructure/utility.py new file mode 100644 index 00000000..57659920 --- /dev/null +++ b/ppstructure/utility.py @@ -0,0 +1,40 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from tools.infer.utility import str2bool, init_args as infer_args + + +def init_args(): + parser = infer_args() + + # params for output + parser.add_argument("--output", type=str, default='./output/table') + # params for table structure + parser.add_argument("--structure_max_len", type=int, default=488) + parser.add_argument("--structure_max_text_length", type=int, default=100) + parser.add_argument("--structure_max_elem_length", type=int, default=800) + parser.add_argument("--structure_max_cell_num", type=int, default=500) + parser.add_argument("--structure_model_dir", type=str) + parser.add_argument("--structure_char_type", type=str, default='en') + parser.add_argument("--structure_char_dict_path", type=str, default="../ppocr/utils/dict/table_structure_dict.txt") + + # params for layout detector + parser.add_argument("--layout_model_dir", type=str) + return parser + + +def parse_args(): + parser = init_args() + return parser.parse_args() diff --git a/tools/infer/predict_system.py b/tools/infer/predict_system.py index 78f5a472..235a075b 100755 --- a/tools/infer/predict_system.py +++ b/tools/infer/predict_system.py @@ -88,7 +88,7 @@ class TextSystem(object): def __call__(self, img, cls=True): ori_im = img.copy() dt_boxes, elapse = self.text_detector(img) - logger.info("dt_boxes num : {}, elapse : {}".format( + logger.debug("dt_boxes num : {}, elapse : {}".format( len(dt_boxes), elapse)) if dt_boxes is None: return None, None @@ -103,11 +103,11 @@ class TextSystem(object): if self.use_angle_cls and cls: img_crop_list, angle_list, elapse = self.text_classifier( img_crop_list) - logger.info("cls num : {}, elapse : {}".format( + logger.debug("cls num : {}, elapse : {}".format( len(img_crop_list), elapse)) rec_res, elapse = self.text_recognizer(img_crop_list) - logger.info("rec_res num : {}, elapse : {}".format( + logger.debug("rec_res num : {}, elapse : {}".format( len(rec_res), elapse)) # self.print_draw_crop_rec_res(img_crop_list, rec_res) filter_boxes, filter_rec_res = [], [] @@ -152,7 +152,7 @@ def main(args): if not flag: img = cv2.imread(image_file) if img is None: - logger.info("error in loading image:{}".format(image_file)) + logger.error("error in loading image:{}".format(image_file)) continue starttime = time.time() dt_boxes, rec_res = text_sys(img) diff --git a/tools/infer/utility.py b/tools/infer/utility.py index 956df5ca..a558f490 100755 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -109,7 +109,7 @@ def init_args(): parser.add_argument("--use_mp", type=str2bool, default=False) parser.add_argument("--total_process_num", type=int, default=1) parser.add_argument("--process_id", type=int, default=0) - + parser.add_argument("--show_log", type=str2bool, default=True) return parser @@ -125,8 +125,8 @@ def create_predictor(args, mode, logger): model_dir = args.cls_model_dir elif mode == 'rec': model_dir = args.rec_model_dir - elif mode == 'table': - model_dir = args.table_model_dir + elif mode == 'structure': + model_dir = args.structure_model_dir else: model_dir = args.e2e_model_dir @@ -246,7 +246,8 @@ def create_predictor(args, mode, logger): config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass") config.switch_use_feed_fetch_ops(False) - if mode == 'table': + config.switch_ir_optim(True) + if mode == 'structure': config.switch_ir_optim(False) # create predictor predictor = inference.create_predictor(config) -- GitLab