提交 bc0d7664 编写于 作者: W WenmuZhou

init commit for paddlestructure

上级 a5f75115
include LICENSE.txt include LICENSE.txt
include README.md include README.md
recursive-include ppocr/utils *.txt utility.py logging.py recursive-include ppocr/utils *.txt utility.py logging.py network.py
recursive-include ppocr/data/ *.py recursive-include ppocr/data/ *.py
recursive-include ppocr/postprocess *.py recursive-include ppocr/postprocess *.py
recursive-include tools/infer *.py recursive-include tools/infer *.py
......
...@@ -19,6 +19,7 @@ __dir__ = os.path.dirname(__file__) ...@@ -19,6 +19,7 @@ __dir__ = os.path.dirname(__file__)
sys.path.append(os.path.join(__dir__, '')) sys.path.append(os.path.join(__dir__, ''))
import cv2 import cv2
import logging
import numpy as np import numpy as np
from pathlib import Path from pathlib import Path
...@@ -150,6 +151,8 @@ class PaddleOCR(predict_system.TextSystem): ...@@ -150,6 +151,8 @@ class PaddleOCR(predict_system.TextSystem):
""" """
params = parse_args(mMain=False) params = parse_args(mMain=False)
params.__dict__.update(**kwargs) params.__dict__.update(**kwargs)
if params.show_log:
logger.setLevel(logging.DEBUG)
self.use_angle_cls = params.use_angle_cls self.use_angle_cls = params.use_angle_cls
lang = params.lang lang = params.lang
latin_lang = [ latin_lang = [
......
...@@ -33,8 +33,7 @@ D ...@@ -33,8 +33,7 @@ D
Π Π
H H
</ </strike>
>
L L
Φ Φ
Χ Χ
......
include LICENSE.txt
include README.md
recursive-include ppocr/utils *.txt utility.py logging.py network.py
recursive-include ppocr/data/ *.py
recursive-include ppocr/postprocess *.py
recursive-include tools/infer *.py
recursive-include table *.py
recursive-include ppstructure *.py
\ No newline at end of file
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import sys
__dir__ = os.path.dirname(__file__)
sys.path.append(os.path.join(__dir__, ''))
import cv2
import numpy as np
from pathlib import Path
from ppocr.utils.logging import get_logger
from predict_system import OCRSystem, save_res
from utility import init_args
logger = get_logger()
from ppocr.utils.utility import check_and_read_gif, get_image_file_list
from ppocr.utils.network import maybe_download, download_with_progressbar
__all__ = ['PaddleStructure']
VERSION = '2.1'
BASE_DIR = os.path.expanduser("~/.paddlestructure/")
model_urls = {
'det': {
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar',
},
'rec': {
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar',
},
'structure': {
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar',
},
}
def parse_args(mMain=True):
import argparse
parser = init_args()
parser.add_help = mMain
for action in parser._actions:
if action.dest in ['rec_char_dict_path', 'structure_char_dict_path']:
action.default = None
if mMain:
return parser.parse_args()
else:
inference_args_dict = {}
for action in parser._actions:
inference_args_dict[action.dest] = action.default
return argparse.Namespace(**inference_args_dict)
class PaddleStructure(OCRSystem):
def __init__(self, **kwargs):
params = parse_args(mMain=False)
params.__dict__.update(**kwargs)
if params.show_log:
logger.setLevel(logging.DEBUG)
params.use_angle_cls = False
# init model dir
if params.det_model_dir is None:
params.det_model_dir = os.path.join(BASE_DIR, VERSION, 'det')
if params.rec_model_dir is None:
params.rec_model_dir = os.path.join(BASE_DIR, VERSION, 'rec')
if params.structure_model_dir is None:
params.structure_model_dir = os.path.join(BASE_DIR, VERSION, 'structure')
# download model
maybe_download(params.det_model_dir, model_urls['det'])
maybe_download(params.det_model_dir, model_urls['rec'])
maybe_download(params.det_model_dir, model_urls['structure'])
if params.rec_char_dict_path is None:
params.rec_char_type = 'EN'
if os.path.exists(str(Path(__file__).parent / 'ppocr/utils/dict/table_dict.txt')):
params.rec_char_dict_path = str(Path(__file__).parent / 'ppocr/utils/dict/table_dict.txt')
else:
params.rec_char_dict_path = str(Path(__file__).parent.parent / 'ppocr/utils/dict/table_dict.txt')
if params.structure_char_dict_path is None:
if os.path.exists(str(Path(__file__).parent / 'ppocr/utils/dict/table_structure_dict.txt')):
params.structure_char_dict_path = str(Path(__file__).parent / 'ppocr/utils/dict/table_structure_dict.txt')
else:
params.structure_char_dict_path = str(Path(__file__).parent.parent / 'ppocr/utils/dict/table_structure_dict.txt')
print(params)
super().__init__(params)
def __call__(self, img):
if isinstance(img, str):
# download net image
if img.startswith('http'):
download_with_progressbar(img, 'tmp.jpg')
img = 'tmp.jpg'
image_file = img
img, flag = check_and_read_gif(image_file)
if not flag:
with open(image_file, 'rb') as f:
np_arr = np.frombuffer(f.read(), dtype=np.uint8)
img = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
if img is None:
logger.error("error in loading image:{}".format(image_file))
return None
if isinstance(img, np.ndarray) and len(img.shape) == 2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
res = super().__call__(img)
return res
def main():
# for cmd
args = parse_args(mMain=True)
image_dir = args.image_dir
save_folder = args.output
if image_dir.startswith('http'):
download_with_progressbar(image_dir, 'tmp.jpg')
image_file_list = ['tmp.jpg']
else:
image_file_list = get_image_file_list(args.image_dir)
if len(image_file_list) == 0:
logger.error('no images find in {}'.format(args.image_dir))
return
structure_engine = PaddleStructure(**(args.__dict__))
for img_path in image_file_list:
img_name = os.path.basename(img_path).split('.')[0]
logger.info('{}{}{}'.format('*' * 10, img_path, '*' * 10))
result = structure_engine(img_path)
save_res(result, args.output, os.path.basename(img_path).split('.')[0])
for item in result:
logger.info(item['res'])
save_res(result, save_folder, img_name)
logger.info('result save to {}'.format(os.path.join(save_folder, img_name)))
if __name__ == '__main__':
table_engine = PaddleStructure(det_model_dir='../inference/table/ch_ppocr_mobile_v2.0_table_det_infer',
rec_model_dir='../inference/table/ch_ppocr_mobile_v2.0_table_rec_infer',
structure_model_dir='../inference/table/ch_ppocr_mobile_v2.0_table_structure_infer',
output='/Users/zhoujun20/Desktop/工作相关/table/table_pr/PaddleOCR/output/table',
show_log=True)
img = cv2.imread('/Users/zhoujun20/Desktop/工作相关/table/table_pr/PaddleOCR/ppstructure/test_imgs/table_1.png')
result = table_engine(img)
for line in result:
print(line)
...@@ -18,97 +18,93 @@ import subprocess ...@@ -18,97 +18,93 @@ import subprocess
__dir__ = os.path.dirname(os.path.abspath(__file__)) __dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__) sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..'))) sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
os.environ["FLAGS_allocator_strategy"] = 'auto_growth' os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
import cv2 import cv2
import copy
import numpy as np import numpy as np
import time import time
import tools.infer.utility as utility
from tools.infer.predict_system import TextSystem import layoutparser as lp
from ppstructure.table.predict_table import TableSystem, to_excel
from ppstructure.layout.predict_layout import LayoutDetector
from ppocr.utils.utility import get_image_file_list, check_and_read_gif from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppocr.utils.logging import get_logger from ppocr.utils.logging import get_logger
from tools.infer.predict_system import TextSystem
from ppstructure.table.predict_table import TableSystem, to_excel
from ppstructure.utility import parse_args
logger = get_logger() logger = get_logger()
def parse_args(): class OCRSystem(object):
parser = utility.init_args()
# params for output
parser.add_argument("--table_output", type=str, default='output/table')
# params for table structure
parser.add_argument("--table_max_len", type=int, default=488)
parser.add_argument("--table_max_text_length", type=int, default=100)
parser.add_argument("--table_max_elem_length", type=int, default=800)
parser.add_argument("--table_max_cell_num", type=int, default=500)
parser.add_argument("--table_model_dir", type=str)
parser.add_argument("--table_char_type", type=str, default='en')
parser.add_argument("--table_char_dict_path", type=str, default="./ppocr/utils/dict/table_structure_dict.txt")
# params for layout detector
parser.add_argument("--layout_model_dir", type=str)
return parser.parse_args()
class OCRSystem():
def __init__(self, args): def __init__(self, args):
self.text_system = TextSystem(args) self.text_system = TextSystem(args)
self.table_system = TableSystem(args) self.table_system = TableSystem(args, self.text_system.text_detector, self.text_system.text_recognizer)
self.table_layout = LayoutDetector(args) self.table_layout = lp.PaddleDetectionLayoutModel("lp://PubLayNet/ppyolov2_r50vd_dcn_365e_publaynet/config",
threshold=0.5, enable_mkldnn=args.enable_mkldnn,
enforce_cpu=not args.use_gpu)
self.use_angle_cls = args.use_angle_cls self.use_angle_cls = args.use_angle_cls
self.drop_score = args.drop_score self.drop_score = args.drop_score
def __call__(self, img): def __call__(self, img):
ori_im = img.copy() ori_im = img.copy()
layout_res = self.table_layout(copy.deepcopy(img)) layout_res = self.table_layout.detect(img[..., ::-1])
res_list = []
for region in layout_res: for region in layout_res:
x1, y1, x2, y2 = region['bbox'] x1, y1, x2, y2 = region.coordinates
x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
roi_img = ori_im[y1:y2, x1:x2, :] roi_img = ori_im[y1:y2, x1:x2, :]
if region['label'] == 'table': if region.type == 'Table':
res = self.text_system(roi_img) res = self.table_system(roi_img)
elif region.type == 'Figure':
continue
else: else:
res = self.text_system(roi_img) filter_boxes, filter_rec_res = self.text_system(roi_img)
region['res'] = res filter_boxes = [x.reshape(-1).tolist() for x in filter_boxes]
return layout_res res = (filter_boxes, filter_rec_res)
res_list.append({'type': region.type, 'bbox': [x1, y1, x2, y2], 'res': res})
return res_list
def save_res(res, save_folder, img_name):
excel_save_folder = os.path.join(save_folder, img_name)
os.makedirs(excel_save_folder, exist_ok=True)
# save res
for region in res:
if region['type'] == 'Table':
excel_path = os.path.join(excel_save_folder, '{}.xlsx'.format(region['bbox']))
to_excel(region['res'], excel_path)
elif region['type'] == 'Figure':
pass
else:
with open(os.path.join(excel_save_folder, 'res.txt'), 'a', encoding='utf8') as f:
for box, rec_res in zip(*region['res']):
f.write('{}\t{}\n'.format(np.array(box).reshape(-1).tolist(), rec_res))
def main(args): def main(args):
image_file_list = get_image_file_list(args.image_dir) image_file_list = get_image_file_list(args.image_dir)
image_file_list = image_file_list
image_file_list = image_file_list[args.process_id::args.total_process_num] image_file_list = image_file_list[args.process_id::args.total_process_num]
save_folder = args.table_output save_folder = args.output
os.makedirs(save_folder, exist_ok=True) os.makedirs(save_folder, exist_ok=True)
text_sys = OCRSystem(args) structure_sys = OCRSystem(args)
img_num = len(image_file_list) img_num = len(image_file_list)
for i, image_file in enumerate(image_file_list): for i, image_file in enumerate(image_file_list):
logger.info("[{}/{}] {}".format(i, img_num, image_file)) logger.info("[{}/{}] {}".format(i, img_num, image_file))
img, flag = check_and_read_gif(image_file) img, flag = check_and_read_gif(image_file)
img_name = os.path.basename(image_file).split('.')[0] img_name = os.path.basename(image_file).split('.')[0]
# excel_path = os.path.join(excel_save_folder, + '.xlsx')
if not flag: if not flag:
img = cv2.imread(image_file) img = cv2.imread(image_file)
if img is None: if img is None:
logger.info("error in loading image:{}".format(image_file)) logger.error("error in loading image:{}".format(image_file))
continue continue
starttime = time.time() starttime = time.time()
res = text_sys(img) res = structure_sys(img)
save_res(res, save_folder, img_name)
excel_save_folder = os.path.join(save_folder, img_name) logger.info('result save to {}'.format(os.path.join(save_folder, img_name)))
os.makedirs(excel_save_folder, exist_ok=True)
# save res
for region in res:
if region['label'] == 'table':
excel_path = os.path.join(excel_save_folder, '{}.xlsx'.format(region['bbox']))
to_excel(region['res'], excel_path)
else:
with open(os.path.join(excel_save_folder, 'res.txt'),'a',encoding='utf8') as f:
for box, rec_res in zip(*region['res']):
f.write('{}\t{}\n'.format(np.array(box).reshape(-1).tolist(), rec_res))
logger.info(res)
elapse = time.time() - starttime elapse = time.time() - starttime
logger.info("Predict time : {:.3f}s".format(elapse)) logger.info("Predict time : {:.3f}s".format(elapse))
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from setuptools import setup
from io import open
import shutil
with open('../requirements.txt', encoding="utf-8-sig") as f:
requirements = f.readlines()
requirements.append('tqdm')
requirements.append('layoutparser')
def readme():
with open('README_ch.md', encoding="utf-8-sig") as f:
README = f.read()
return README
shutil.copytree('../ppocr','./ppocr')
shutil.copytree('../tools','./tools')
shutil.copytree('../ppstructure','./ppstructure')
setup(
name='paddlestructure',
packages=['paddlestructure'],
package_dir={'paddlestructure': ''},
include_package_data=True,
entry_points={"console_scripts": ["paddlestructure= paddlestructure.paddlestructure:main"]},
version='2.0.6',
install_requires=requirements,
license='Apache License 2.0',
description='Awesome OCR toolkits based on PaddlePaddle (8.6M ultra-lightweight pre-trained model, support training and deployment among server, mobile, embeded and IoT devices',
long_description=readme(),
long_description_content_type='text/markdown',
url='https://github.com/PaddlePaddle/PaddleOCR',
download_url='https://github.com/PaddlePaddle/PaddleOCR.git',
keywords=[
'ocr textdetection textrecognition paddleocr crnn east star-net rosetta ocrlite db chineseocr chinesetextdetection chinesetextrecognition'
],
classifiers=[
'Intended Audience :: Developers', 'Operating System :: OS Independent',
'Natural Language :: Chinese (Simplified)',
'Programming Language :: Python :: 3',
'Programming Language :: Python :: 3.2',
'Programming Language :: Python :: 3.3',
'Programming Language :: Python :: 3.4',
'Programming Language :: Python :: 3.5',
'Programming Language :: Python :: 3.6',
'Programming Language :: Python :: 3.7', 'Topic :: Utilities'
], )
shutil.rmtree('ppocr')
shutil.rmtree('tools')
shutil.rmtree('ppstructure')
\ No newline at end of file
# 表格结构和内容预测
先cd到PaddleOCR/ppstructure目录下
预测
```python
python3 table/predict_table.py --det_model_dir=../inference/db --rec_model_dir=../inference/rec_mv3_large1.0/infer --table_model_dir=../inference/explite3/infer --image_dir=../table/imgs/PMC3006023_004_00.png --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=EN --det_limit_side_len=736 --det_limit_type=min --table_output ../output/table
```
运行完成后,每张图片的excel表格会保存到table_output字段指定的目录下
eval
```python
python3 table/eval_table.py --det_model_dir=../inference/db --rec_model_dir=../inference/rec_mv3_large1.0/infer --table_model_dir=../inference/explite3/infer --image_dir=../table/imgs --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=EN --det_limit_side_len=736 --det_limit_type=min --gt_path=path/to/gt.json
```
...@@ -15,16 +15,21 @@ import os ...@@ -15,16 +15,21 @@ import os
import sys import sys
__dir__ = os.path.dirname(os.path.abspath(__file__)) __dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__) sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '..'))) sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
import cv2 import cv2
import json import json
from tqdm import tqdm from tqdm import tqdm
from ppstructure.table.table_metric import TEDS from ppstructure.table.table_metric import TEDS
from ppstructure.table.predict_table import TableSystem from ppstructure.table.predict_table import TableSystem
from ppstructure.predict_system import parse_args from ppstructure.utility import init_args
def parse_args():
parser = init_args()
parser.add_argument("--gt_path", type=str)
return parser.parse_args()
def main(gt_path, img_root, args): def main(gt_path, img_root, args):
teds = TEDS(n_jobs=16) teds = TEDS(n_jobs=16)
...@@ -33,6 +38,8 @@ def main(gt_path, img_root, args): ...@@ -33,6 +38,8 @@ def main(gt_path, img_root, args):
pred_htmls = [] pred_htmls = []
gt_htmls = [] gt_htmls = []
for img_name in tqdm(jsons_gt): for img_name in tqdm(jsons_gt):
if img_name != 'PMC1064865_002_00.png':
continue
# 读取信息 # 读取信息
img = cv2.imread(os.path.join(img_root,img_name)) img = cv2.imread(os.path.join(img_root,img_name))
pred_html = text_sys(img) pred_html = text_sys(img)
...@@ -61,6 +68,4 @@ def get_gt_html(gt_structures, contents_with_block): ...@@ -61,6 +68,4 @@ def get_gt_html(gt_structures, contents_with_block):
if __name__ == '__main__': if __name__ == '__main__':
args = parse_args() args = parse_args()
gt_path = 'table/match_code/f_gt_bbox.json' main(args.gt_path,args.image_dir, args)
img_root = 'table/imgs'
main(gt_path,img_root, args)
...@@ -194,21 +194,3 @@ def matcher_structure(gt_bboxes, pred_bboxes_rows, pred_bboxes): ...@@ -194,21 +194,3 @@ def matcher_structure(gt_bboxes, pred_bboxes_rows, pred_bboxes):
matched[index].append(i) matched[index].append(i)
pre_bbox = gt_box pre_bbox = gt_box
return matched return matched
def main():
detect_bboxes = json.load(open('./f_detecion_bbox.json'))
gt_bboxes = json.load(open('./f_gt_bbox.json'))
all_node = 0
matched_right = 0
key = 'PMC4796501_003_00.png'
print(key)
gt_bbox = gt_bboxes[key]
pred_bbox = detect_bboxes[key]
matched = matcher(gt_bbox, pred_bbox)
print(matched)
if __name__ == "__main__":
main()
...@@ -40,7 +40,7 @@ class TableStructurer(object): ...@@ -40,7 +40,7 @@ class TableStructurer(object):
def __init__(self, args): def __init__(self, args):
pre_process_list = [{ pre_process_list = [{
'ResizeTableImage': { 'ResizeTableImage': {
'max_len': args.table_max_len 'max_len': args.structure_max_len
} }
}, { }, {
'NormalizeImage': { 'NormalizeImage': {
...@@ -60,17 +60,17 @@ class TableStructurer(object): ...@@ -60,17 +60,17 @@ class TableStructurer(object):
}] }]
postprocess_params = { postprocess_params = {
'name': 'TableLabelDecode', 'name': 'TableLabelDecode',
"character_type": args.table_char_type, "character_type": args.structure_char_type,
"character_dict_path": args.table_char_dict_path, "character_dict_path": args.structure_char_dict_path,
"max_text_length": args.table_max_text_length, "max_text_length": args.structure_max_text_length,
"max_elem_length": args.table_max_elem_length, "max_elem_length": args.structure_max_elem_length,
"max_cell_num": args.table_max_cell_num "max_cell_num": args.structure_max_cell_num
} }
self.preprocess_op = create_operators(pre_process_list) self.preprocess_op = create_operators(pre_process_list)
self.postprocess_op = build_post_process(postprocess_params) self.postprocess_op = build_post_process(postprocess_params)
self.predictor, self.input_tensor, self.output_tensors = \ self.predictor, self.input_tensor, self.output_tensors = \
utility.create_predictor(args, 'table', logger) utility.create_predictor(args, 'structure', logger)
def __call__(self, img): def __call__(self, img):
ori_im = img.copy() ori_im = img.copy()
......
...@@ -18,6 +18,7 @@ import subprocess ...@@ -18,6 +18,7 @@ import subprocess
__dir__ = os.path.dirname(os.path.abspath(__file__)) __dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__) sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..'))) sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
os.environ["FLAGS_allocator_strategy"] = 'auto_growth' os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
...@@ -25,13 +26,13 @@ import cv2 ...@@ -25,13 +26,13 @@ import cv2
import copy import copy
import numpy as np import numpy as np
import time import time
import tools.infer.utility as utility
import tools.infer.predict_rec as predict_rec import tools.infer.predict_rec as predict_rec
import tools.infer.predict_det as predict_det import tools.infer.predict_det as predict_det
import ppstructure.table.predict_structure as predict_strture import ppstructure.table.predict_structure as predict_strture
from ppocr.utils.utility import get_image_file_list, check_and_read_gif from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppocr.utils.logging import get_logger from ppocr.utils.logging import get_logger
from ppstructure.table.matcher import distance, compute_iou from matcher import distance, compute_iou
from ppstructure.utility import parse_args
logger = get_logger() logger = get_logger()
...@@ -52,12 +53,10 @@ def expand(pix, det_box, shape): ...@@ -52,12 +53,10 @@ def expand(pix, det_box, shape):
class TableSystem(object): class TableSystem(object):
def __init__(self, args): def __init__(self, args, text_detector=None, text_recognizer=None):
self.text_detector = predict_det.TextDetector(args) self.text_detector = predict_det.TextDetector(args) if text_detector is None else text_detector
self.text_recognizer = predict_rec.TextRecognizer(args) self.text_recognizer = predict_rec.TextRecognizer(args) if text_recognizer is None else text_recognizer
self.table_structurer = predict_strture.TableStructurer(args) self.table_structurer = predict_strture.TableStructurer(args)
self.use_angle_cls = args.use_angle_cls
self.drop_score = args.drop_score
def __call__(self, img): def __call__(self, img):
ori_im = img.copy() ori_im = img.copy()
...@@ -75,8 +74,8 @@ class TableSystem(object): ...@@ -75,8 +74,8 @@ class TableSystem(object):
r_boxes.append(box) r_boxes.append(box)
dt_boxes = np.array(r_boxes) dt_boxes = np.array(r_boxes)
# logger.info("dt_boxes num : {}, elapse : {}".format( logger.debug("dt_boxes num : {}, elapse : {}".format(
# len(dt_boxes), elapse)) len(dt_boxes), elapse))
if dt_boxes is None: if dt_boxes is None:
return None, None return None, None
img_crop_list = [] img_crop_list = []
...@@ -87,8 +86,8 @@ class TableSystem(object): ...@@ -87,8 +86,8 @@ class TableSystem(object):
text_rect = ori_im[int(y0):int(y1), int(x0):int(x1), :] text_rect = ori_im[int(y0):int(y1), int(x0):int(x1), :]
img_crop_list.append(text_rect) img_crop_list.append(text_rect)
rec_res, elapse = self.text_recognizer(img_crop_list) rec_res, elapse = self.text_recognizer(img_crop_list)
# logger.info("rec_res num : {}, elapse : {}".format( logger.debug("rec_res num : {}, elapse : {}".format(
# len(rec_res), elapse)) len(rec_res), elapse))
pred_html, pred = self.rebuild_table(structure_res, dt_boxes, rec_res) pred_html, pred = self.rebuild_table(structure_res, dt_boxes, rec_res)
return pred_html return pred_html
...@@ -172,6 +171,7 @@ def sorted_boxes(dt_boxes): ...@@ -172,6 +171,7 @@ def sorted_boxes(dt_boxes):
_boxes[i + 1] = tmp _boxes[i + 1] = tmp
return _boxes return _boxes
def to_excel(html_table, excel_path): def to_excel(html_table, excel_path):
from tablepyxl import tablepyxl from tablepyxl import tablepyxl
tablepyxl.document_to_xl(html_table, excel_path) tablepyxl.document_to_xl(html_table, excel_path)
...@@ -180,19 +180,18 @@ def to_excel(html_table, excel_path): ...@@ -180,19 +180,18 @@ def to_excel(html_table, excel_path):
def main(args): def main(args):
image_file_list = get_image_file_list(args.image_dir) image_file_list = get_image_file_list(args.image_dir)
image_file_list = image_file_list[args.process_id::args.total_process_num] image_file_list = image_file_list[args.process_id::args.total_process_num]
excel_save_folder = 'output/table' os.makedirs(args.output, exist_ok=True)
os.makedirs(excel_save_folder, exist_ok=True)
text_sys = TableSystem(args) text_sys = TableSystem(args)
img_num = len(image_file_list) img_num = len(image_file_list)
for i, image_file in enumerate(image_file_list): for i, image_file in enumerate(image_file_list):
logger.info("[{}/{}] {}".format(i, img_num, image_file)) logger.info("[{}/{}] {}".format(i, img_num, image_file))
img, flag = check_and_read_gif(image_file) img, flag = check_and_read_gif(image_file)
excel_path = os.path.join(excel_save_folder, os.path.basename(image_file).split('.')[0] + '.xlsx') excel_path = os.path.join(args.table_output, os.path.basename(image_file).split('.')[0] + '.xlsx')
if not flag: if not flag:
img = cv2.imread(image_file) img = cv2.imread(image_file)
if img is None: if img is None:
logger.info("error in loading image:{}".format(image_file)) logger.error("error in loading image:{}".format(image_file))
continue continue
starttime = time.time() starttime = time.time()
pred_html = text_sys(img) pred_html = text_sys(img)
...@@ -205,7 +204,7 @@ def main(args): ...@@ -205,7 +204,7 @@ def main(args):
if __name__ == "__main__": if __name__ == "__main__":
args = utility.parse_args() args = parse_args()
if args.use_mp: if args.use_mp:
p_list = [] p_list = []
total_process_num = args.total_process_num total_process_num = args.total_process_num
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from tools.infer.utility import str2bool, init_args as infer_args
def init_args():
parser = infer_args()
# params for output
parser.add_argument("--output", type=str, default='./output/table')
# params for table structure
parser.add_argument("--structure_max_len", type=int, default=488)
parser.add_argument("--structure_max_text_length", type=int, default=100)
parser.add_argument("--structure_max_elem_length", type=int, default=800)
parser.add_argument("--structure_max_cell_num", type=int, default=500)
parser.add_argument("--structure_model_dir", type=str)
parser.add_argument("--structure_char_type", type=str, default='en')
parser.add_argument("--structure_char_dict_path", type=str, default="../ppocr/utils/dict/table_structure_dict.txt")
# params for layout detector
parser.add_argument("--layout_model_dir", type=str)
return parser
def parse_args():
parser = init_args()
return parser.parse_args()
...@@ -88,7 +88,7 @@ class TextSystem(object): ...@@ -88,7 +88,7 @@ class TextSystem(object):
def __call__(self, img, cls=True): def __call__(self, img, cls=True):
ori_im = img.copy() ori_im = img.copy()
dt_boxes, elapse = self.text_detector(img) dt_boxes, elapse = self.text_detector(img)
logger.info("dt_boxes num : {}, elapse : {}".format( logger.debug("dt_boxes num : {}, elapse : {}".format(
len(dt_boxes), elapse)) len(dt_boxes), elapse))
if dt_boxes is None: if dt_boxes is None:
return None, None return None, None
...@@ -103,11 +103,11 @@ class TextSystem(object): ...@@ -103,11 +103,11 @@ class TextSystem(object):
if self.use_angle_cls and cls: if self.use_angle_cls and cls:
img_crop_list, angle_list, elapse = self.text_classifier( img_crop_list, angle_list, elapse = self.text_classifier(
img_crop_list) img_crop_list)
logger.info("cls num : {}, elapse : {}".format( logger.debug("cls num : {}, elapse : {}".format(
len(img_crop_list), elapse)) len(img_crop_list), elapse))
rec_res, elapse = self.text_recognizer(img_crop_list) rec_res, elapse = self.text_recognizer(img_crop_list)
logger.info("rec_res num : {}, elapse : {}".format( logger.debug("rec_res num : {}, elapse : {}".format(
len(rec_res), elapse)) len(rec_res), elapse))
# self.print_draw_crop_rec_res(img_crop_list, rec_res) # self.print_draw_crop_rec_res(img_crop_list, rec_res)
filter_boxes, filter_rec_res = [], [] filter_boxes, filter_rec_res = [], []
...@@ -152,7 +152,7 @@ def main(args): ...@@ -152,7 +152,7 @@ def main(args):
if not flag: if not flag:
img = cv2.imread(image_file) img = cv2.imread(image_file)
if img is None: if img is None:
logger.info("error in loading image:{}".format(image_file)) logger.error("error in loading image:{}".format(image_file))
continue continue
starttime = time.time() starttime = time.time()
dt_boxes, rec_res = text_sys(img) dt_boxes, rec_res = text_sys(img)
......
...@@ -109,7 +109,7 @@ def init_args(): ...@@ -109,7 +109,7 @@ def init_args():
parser.add_argument("--use_mp", type=str2bool, default=False) parser.add_argument("--use_mp", type=str2bool, default=False)
parser.add_argument("--total_process_num", type=int, default=1) parser.add_argument("--total_process_num", type=int, default=1)
parser.add_argument("--process_id", type=int, default=0) parser.add_argument("--process_id", type=int, default=0)
parser.add_argument("--show_log", type=str2bool, default=True)
return parser return parser
...@@ -125,8 +125,8 @@ def create_predictor(args, mode, logger): ...@@ -125,8 +125,8 @@ def create_predictor(args, mode, logger):
model_dir = args.cls_model_dir model_dir = args.cls_model_dir
elif mode == 'rec': elif mode == 'rec':
model_dir = args.rec_model_dir model_dir = args.rec_model_dir
elif mode == 'table': elif mode == 'structure':
model_dir = args.table_model_dir model_dir = args.structure_model_dir
else: else:
model_dir = args.e2e_model_dir model_dir = args.e2e_model_dir
...@@ -246,7 +246,8 @@ def create_predictor(args, mode, logger): ...@@ -246,7 +246,8 @@ def create_predictor(args, mode, logger):
config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass") config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
config.switch_use_feed_fetch_ops(False) config.switch_use_feed_fetch_ops(False)
if mode == 'table': config.switch_ir_optim(True)
if mode == 'structure':
config.switch_ir_optim(False) config.switch_ir_optim(False)
# create predictor # create predictor
predictor = inference.create_predictor(config) predictor = inference.create_predictor(config)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册