提交 bc0d7664 编写于 作者: W WenmuZhou

init commit for paddlestructure

上级 a5f75115
include LICENSE.txt
include README.md
recursive-include ppocr/utils *.txt utility.py logging.py
recursive-include ppocr/utils *.txt utility.py logging.py network.py
recursive-include ppocr/data/ *.py
recursive-include ppocr/postprocess *.py
recursive-include tools/infer *.py
......
......@@ -19,6 +19,7 @@ __dir__ = os.path.dirname(__file__)
sys.path.append(os.path.join(__dir__, ''))
import cv2
import logging
import numpy as np
from pathlib import Path
......@@ -150,6 +151,8 @@ class PaddleOCR(predict_system.TextSystem):
"""
params = parse_args(mMain=False)
params.__dict__.update(**kwargs)
if params.show_log:
logger.setLevel(logging.DEBUG)
self.use_angle_cls = params.use_angle_cls
lang = params.lang
latin_lang = [
......
......@@ -33,8 +33,7 @@ D
Π
H
</
>
</strike>
L
Φ
Χ
......
include LICENSE.txt
include README.md
recursive-include ppocr/utils *.txt utility.py logging.py network.py
recursive-include ppocr/data/ *.py
recursive-include ppocr/postprocess *.py
recursive-include tools/infer *.py
recursive-include table *.py
recursive-include ppstructure *.py
\ No newline at end of file
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import sys
__dir__ = os.path.dirname(__file__)
sys.path.append(os.path.join(__dir__, ''))
import cv2
import numpy as np
from pathlib import Path
from ppocr.utils.logging import get_logger
from predict_system import OCRSystem, save_res
from utility import init_args
logger = get_logger()
from ppocr.utils.utility import check_and_read_gif, get_image_file_list
from ppocr.utils.network import maybe_download, download_with_progressbar
__all__ = ['PaddleStructure']
VERSION = '2.1'
BASE_DIR = os.path.expanduser("~/.paddlestructure/")
model_urls = {
'det': {
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar',
},
'rec': {
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar',
},
'structure': {
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar',
},
}
def parse_args(mMain=True):
import argparse
parser = init_args()
parser.add_help = mMain
for action in parser._actions:
if action.dest in ['rec_char_dict_path', 'structure_char_dict_path']:
action.default = None
if mMain:
return parser.parse_args()
else:
inference_args_dict = {}
for action in parser._actions:
inference_args_dict[action.dest] = action.default
return argparse.Namespace(**inference_args_dict)
class PaddleStructure(OCRSystem):
def __init__(self, **kwargs):
params = parse_args(mMain=False)
params.__dict__.update(**kwargs)
if params.show_log:
logger.setLevel(logging.DEBUG)
params.use_angle_cls = False
# init model dir
if params.det_model_dir is None:
params.det_model_dir = os.path.join(BASE_DIR, VERSION, 'det')
if params.rec_model_dir is None:
params.rec_model_dir = os.path.join(BASE_DIR, VERSION, 'rec')
if params.structure_model_dir is None:
params.structure_model_dir = os.path.join(BASE_DIR, VERSION, 'structure')
# download model
maybe_download(params.det_model_dir, model_urls['det'])
maybe_download(params.det_model_dir, model_urls['rec'])
maybe_download(params.det_model_dir, model_urls['structure'])
if params.rec_char_dict_path is None:
params.rec_char_type = 'EN'
if os.path.exists(str(Path(__file__).parent / 'ppocr/utils/dict/table_dict.txt')):
params.rec_char_dict_path = str(Path(__file__).parent / 'ppocr/utils/dict/table_dict.txt')
else:
params.rec_char_dict_path = str(Path(__file__).parent.parent / 'ppocr/utils/dict/table_dict.txt')
if params.structure_char_dict_path is None:
if os.path.exists(str(Path(__file__).parent / 'ppocr/utils/dict/table_structure_dict.txt')):
params.structure_char_dict_path = str(Path(__file__).parent / 'ppocr/utils/dict/table_structure_dict.txt')
else:
params.structure_char_dict_path = str(Path(__file__).parent.parent / 'ppocr/utils/dict/table_structure_dict.txt')
print(params)
super().__init__(params)
def __call__(self, img):
if isinstance(img, str):
# download net image
if img.startswith('http'):
download_with_progressbar(img, 'tmp.jpg')
img = 'tmp.jpg'
image_file = img
img, flag = check_and_read_gif(image_file)
if not flag:
with open(image_file, 'rb') as f:
np_arr = np.frombuffer(f.read(), dtype=np.uint8)
img = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
if img is None:
logger.error("error in loading image:{}".format(image_file))
return None
if isinstance(img, np.ndarray) and len(img.shape) == 2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
res = super().__call__(img)
return res
def main():
# for cmd
args = parse_args(mMain=True)
image_dir = args.image_dir
save_folder = args.output
if image_dir.startswith('http'):
download_with_progressbar(image_dir, 'tmp.jpg')
image_file_list = ['tmp.jpg']
else:
image_file_list = get_image_file_list(args.image_dir)
if len(image_file_list) == 0:
logger.error('no images find in {}'.format(args.image_dir))
return
structure_engine = PaddleStructure(**(args.__dict__))
for img_path in image_file_list:
img_name = os.path.basename(img_path).split('.')[0]
logger.info('{}{}{}'.format('*' * 10, img_path, '*' * 10))
result = structure_engine(img_path)
save_res(result, args.output, os.path.basename(img_path).split('.')[0])
for item in result:
logger.info(item['res'])
save_res(result, save_folder, img_name)
logger.info('result save to {}'.format(os.path.join(save_folder, img_name)))
if __name__ == '__main__':
table_engine = PaddleStructure(det_model_dir='../inference/table/ch_ppocr_mobile_v2.0_table_det_infer',
rec_model_dir='../inference/table/ch_ppocr_mobile_v2.0_table_rec_infer',
structure_model_dir='../inference/table/ch_ppocr_mobile_v2.0_table_structure_infer',
output='/Users/zhoujun20/Desktop/工作相关/table/table_pr/PaddleOCR/output/table',
show_log=True)
img = cv2.imread('/Users/zhoujun20/Desktop/工作相关/table/table_pr/PaddleOCR/ppstructure/test_imgs/table_1.png')
result = table_engine(img)
for line in result:
print(line)
......@@ -18,97 +18,93 @@ import subprocess
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
import cv2
import copy
import numpy as np
import time
import tools.infer.utility as utility
from tools.infer.predict_system import TextSystem
from ppstructure.table.predict_table import TableSystem, to_excel
from ppstructure.layout.predict_layout import LayoutDetector
import layoutparser as lp
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppocr.utils.logging import get_logger
from tools.infer.predict_system import TextSystem
from ppstructure.table.predict_table import TableSystem, to_excel
from ppstructure.utility import parse_args
logger = get_logger()
def parse_args():
parser = utility.init_args()
# params for output
parser.add_argument("--table_output", type=str, default='output/table')
# params for table structure
parser.add_argument("--table_max_len", type=int, default=488)
parser.add_argument("--table_max_text_length", type=int, default=100)
parser.add_argument("--table_max_elem_length", type=int, default=800)
parser.add_argument("--table_max_cell_num", type=int, default=500)
parser.add_argument("--table_model_dir", type=str)
parser.add_argument("--table_char_type", type=str, default='en')
parser.add_argument("--table_char_dict_path", type=str, default="./ppocr/utils/dict/table_structure_dict.txt")
# params for layout detector
parser.add_argument("--layout_model_dir", type=str)
return parser.parse_args()
class OCRSystem():
class OCRSystem(object):
def __init__(self, args):
self.text_system = TextSystem(args)
self.table_system = TableSystem(args)
self.table_layout = LayoutDetector(args)
self.table_system = TableSystem(args, self.text_system.text_detector, self.text_system.text_recognizer)
self.table_layout = lp.PaddleDetectionLayoutModel("lp://PubLayNet/ppyolov2_r50vd_dcn_365e_publaynet/config",
threshold=0.5, enable_mkldnn=args.enable_mkldnn,
enforce_cpu=not args.use_gpu)
self.use_angle_cls = args.use_angle_cls
self.drop_score = args.drop_score
def __call__(self, img):
ori_im = img.copy()
layout_res = self.table_layout(copy.deepcopy(img))
layout_res = self.table_layout.detect(img[..., ::-1])
res_list = []
for region in layout_res:
x1, y1, x2, y2 = region['bbox']
x1, y1, x2, y2 = region.coordinates
x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
roi_img = ori_im[y1:y2, x1:x2, :]
if region['label'] == 'table':
res = self.text_system(roi_img)
if region.type == 'Table':
res = self.table_system(roi_img)
elif region.type == 'Figure':
continue
else:
res = self.text_system(roi_img)
region['res'] = res
return layout_res
filter_boxes, filter_rec_res = self.text_system(roi_img)
filter_boxes = [x.reshape(-1).tolist() for x in filter_boxes]
res = (filter_boxes, filter_rec_res)
res_list.append({'type': region.type, 'bbox': [x1, y1, x2, y2], 'res': res})
return res_list
def save_res(res, save_folder, img_name):
excel_save_folder = os.path.join(save_folder, img_name)
os.makedirs(excel_save_folder, exist_ok=True)
# save res
for region in res:
if region['type'] == 'Table':
excel_path = os.path.join(excel_save_folder, '{}.xlsx'.format(region['bbox']))
to_excel(region['res'], excel_path)
elif region['type'] == 'Figure':
pass
else:
with open(os.path.join(excel_save_folder, 'res.txt'), 'a', encoding='utf8') as f:
for box, rec_res in zip(*region['res']):
f.write('{}\t{}\n'.format(np.array(box).reshape(-1).tolist(), rec_res))
def main(args):
image_file_list = get_image_file_list(args.image_dir)
image_file_list = image_file_list
image_file_list = image_file_list[args.process_id::args.total_process_num]
save_folder = args.table_output
save_folder = args.output
os.makedirs(save_folder, exist_ok=True)
text_sys = OCRSystem(args)
structure_sys = OCRSystem(args)
img_num = len(image_file_list)
for i, image_file in enumerate(image_file_list):
logger.info("[{}/{}] {}".format(i, img_num, image_file))
img, flag = check_and_read_gif(image_file)
img_name = os.path.basename(image_file).split('.')[0]
# excel_path = os.path.join(excel_save_folder, + '.xlsx')
if not flag:
img = cv2.imread(image_file)
if img is None:
logger.info("error in loading image:{}".format(image_file))
logger.error("error in loading image:{}".format(image_file))
continue
starttime = time.time()
res = text_sys(img)
excel_save_folder = os.path.join(save_folder, img_name)
os.makedirs(excel_save_folder, exist_ok=True)
# save res
for region in res:
if region['label'] == 'table':
excel_path = os.path.join(excel_save_folder, '{}.xlsx'.format(region['bbox']))
to_excel(region['res'], excel_path)
else:
with open(os.path.join(excel_save_folder, 'res.txt'),'a',encoding='utf8') as f:
for box, rec_res in zip(*region['res']):
f.write('{}\t{}\n'.format(np.array(box).reshape(-1).tolist(), rec_res))
logger.info(res)
res = structure_sys(img)
save_res(res, save_folder, img_name)
logger.info('result save to {}'.format(os.path.join(save_folder, img_name)))
elapse = time.time() - starttime
logger.info("Predict time : {:.3f}s".format(elapse))
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from setuptools import setup
from io import open
import shutil
with open('../requirements.txt', encoding="utf-8-sig") as f:
requirements = f.readlines()
requirements.append('tqdm')
requirements.append('layoutparser')
def readme():
with open('README_ch.md', encoding="utf-8-sig") as f:
README = f.read()
return README
shutil.copytree('../ppocr','./ppocr')
shutil.copytree('../tools','./tools')
shutil.copytree('../ppstructure','./ppstructure')
setup(
name='paddlestructure',
packages=['paddlestructure'],
package_dir={'paddlestructure': ''},
include_package_data=True,
entry_points={"console_scripts": ["paddlestructure= paddlestructure.paddlestructure:main"]},
version='2.0.6',
install_requires=requirements,
license='Apache License 2.0',
description='Awesome OCR toolkits based on PaddlePaddle (8.6M ultra-lightweight pre-trained model, support training and deployment among server, mobile, embeded and IoT devices',
long_description=readme(),
long_description_content_type='text/markdown',
url='https://github.com/PaddlePaddle/PaddleOCR',
download_url='https://github.com/PaddlePaddle/PaddleOCR.git',
keywords=[
'ocr textdetection textrecognition paddleocr crnn east star-net rosetta ocrlite db chineseocr chinesetextdetection chinesetextrecognition'
],
classifiers=[
'Intended Audience :: Developers', 'Operating System :: OS Independent',
'Natural Language :: Chinese (Simplified)',
'Programming Language :: Python :: 3',
'Programming Language :: Python :: 3.2',
'Programming Language :: Python :: 3.3',
'Programming Language :: Python :: 3.4',
'Programming Language :: Python :: 3.5',
'Programming Language :: Python :: 3.6',
'Programming Language :: Python :: 3.7', 'Topic :: Utilities'
], )
shutil.rmtree('ppocr')
shutil.rmtree('tools')
shutil.rmtree('ppstructure')
\ No newline at end of file
# 表格结构和内容预测
先cd到PaddleOCR/ppstructure目录下
预测
```python
python3 table/predict_table.py --det_model_dir=../inference/db --rec_model_dir=../inference/rec_mv3_large1.0/infer --table_model_dir=../inference/explite3/infer --image_dir=../table/imgs/PMC3006023_004_00.png --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=EN --det_limit_side_len=736 --det_limit_type=min --table_output ../output/table
```
运行完成后,每张图片的excel表格会保存到table_output字段指定的目录下
eval
```python
python3 table/eval_table.py --det_model_dir=../inference/db --rec_model_dir=../inference/rec_mv3_large1.0/infer --table_model_dir=../inference/explite3/infer --image_dir=../table/imgs --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=EN --det_limit_side_len=736 --det_limit_type=min --gt_path=path/to/gt.json
```
......@@ -15,16 +15,21 @@ import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
import cv2
import json
from tqdm import tqdm
from ppstructure.table.table_metric import TEDS
from ppstructure.table.predict_table import TableSystem
from ppstructure.predict_system import parse_args
from ppstructure.utility import init_args
def parse_args():
parser = init_args()
parser.add_argument("--gt_path", type=str)
return parser.parse_args()
def main(gt_path, img_root, args):
teds = TEDS(n_jobs=16)
......@@ -33,6 +38,8 @@ def main(gt_path, img_root, args):
pred_htmls = []
gt_htmls = []
for img_name in tqdm(jsons_gt):
if img_name != 'PMC1064865_002_00.png':
continue
# 读取信息
img = cv2.imread(os.path.join(img_root,img_name))
pred_html = text_sys(img)
......@@ -61,6 +68,4 @@ def get_gt_html(gt_structures, contents_with_block):
if __name__ == '__main__':
args = parse_args()
gt_path = 'table/match_code/f_gt_bbox.json'
img_root = 'table/imgs'
main(gt_path,img_root, args)
main(args.gt_path,args.image_dir, args)
......@@ -194,21 +194,3 @@ def matcher_structure(gt_bboxes, pred_bboxes_rows, pred_bboxes):
matched[index].append(i)
pre_bbox = gt_box
return matched
def main():
detect_bboxes = json.load(open('./f_detecion_bbox.json'))
gt_bboxes = json.load(open('./f_gt_bbox.json'))
all_node = 0
matched_right = 0
key = 'PMC4796501_003_00.png'
print(key)
gt_bbox = gt_bboxes[key]
pred_bbox = detect_bboxes[key]
matched = matcher(gt_bbox, pred_bbox)
print(matched)
if __name__ == "__main__":
main()
......@@ -40,7 +40,7 @@ class TableStructurer(object):
def __init__(self, args):
pre_process_list = [{
'ResizeTableImage': {
'max_len': args.table_max_len
'max_len': args.structure_max_len
}
}, {
'NormalizeImage': {
......@@ -60,17 +60,17 @@ class TableStructurer(object):
}]
postprocess_params = {
'name': 'TableLabelDecode',
"character_type": args.table_char_type,
"character_dict_path": args.table_char_dict_path,
"max_text_length": args.table_max_text_length,
"max_elem_length": args.table_max_elem_length,
"max_cell_num": args.table_max_cell_num
"character_type": args.structure_char_type,
"character_dict_path": args.structure_char_dict_path,
"max_text_length": args.structure_max_text_length,
"max_elem_length": args.structure_max_elem_length,
"max_cell_num": args.structure_max_cell_num
}
self.preprocess_op = create_operators(pre_process_list)
self.postprocess_op = build_post_process(postprocess_params)
self.predictor, self.input_tensor, self.output_tensors = \
utility.create_predictor(args, 'table', logger)
utility.create_predictor(args, 'structure', logger)
def __call__(self, img):
ori_im = img.copy()
......
......@@ -18,6 +18,7 @@ import subprocess
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
......@@ -25,13 +26,13 @@ import cv2
import copy
import numpy as np
import time
import tools.infer.utility as utility
import tools.infer.predict_rec as predict_rec
import tools.infer.predict_det as predict_det
import ppstructure.table.predict_structure as predict_strture
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppocr.utils.logging import get_logger
from ppstructure.table.matcher import distance, compute_iou
from matcher import distance, compute_iou
from ppstructure.utility import parse_args
logger = get_logger()
......@@ -52,12 +53,10 @@ def expand(pix, det_box, shape):
class TableSystem(object):
def __init__(self, args):
self.text_detector = predict_det.TextDetector(args)
self.text_recognizer = predict_rec.TextRecognizer(args)
def __init__(self, args, text_detector=None, text_recognizer=None):
self.text_detector = predict_det.TextDetector(args) if text_detector is None else text_detector
self.text_recognizer = predict_rec.TextRecognizer(args) if text_recognizer is None else text_recognizer
self.table_structurer = predict_strture.TableStructurer(args)
self.use_angle_cls = args.use_angle_cls
self.drop_score = args.drop_score
def __call__(self, img):
ori_im = img.copy()
......@@ -75,8 +74,8 @@ class TableSystem(object):
r_boxes.append(box)
dt_boxes = np.array(r_boxes)
# logger.info("dt_boxes num : {}, elapse : {}".format(
# len(dt_boxes), elapse))
logger.debug("dt_boxes num : {}, elapse : {}".format(
len(dt_boxes), elapse))
if dt_boxes is None:
return None, None
img_crop_list = []
......@@ -87,8 +86,8 @@ class TableSystem(object):
text_rect = ori_im[int(y0):int(y1), int(x0):int(x1), :]
img_crop_list.append(text_rect)
rec_res, elapse = self.text_recognizer(img_crop_list)
# logger.info("rec_res num : {}, elapse : {}".format(
# len(rec_res), elapse))
logger.debug("rec_res num : {}, elapse : {}".format(
len(rec_res), elapse))
pred_html, pred = self.rebuild_table(structure_res, dt_boxes, rec_res)
return pred_html
......@@ -172,6 +171,7 @@ def sorted_boxes(dt_boxes):
_boxes[i + 1] = tmp
return _boxes
def to_excel(html_table, excel_path):
from tablepyxl import tablepyxl
tablepyxl.document_to_xl(html_table, excel_path)
......@@ -180,19 +180,18 @@ def to_excel(html_table, excel_path):
def main(args):
image_file_list = get_image_file_list(args.image_dir)
image_file_list = image_file_list[args.process_id::args.total_process_num]
excel_save_folder = 'output/table'
os.makedirs(excel_save_folder, exist_ok=True)
os.makedirs(args.output, exist_ok=True)
text_sys = TableSystem(args)
img_num = len(image_file_list)
for i, image_file in enumerate(image_file_list):
logger.info("[{}/{}] {}".format(i, img_num, image_file))
img, flag = check_and_read_gif(image_file)
excel_path = os.path.join(excel_save_folder, os.path.basename(image_file).split('.')[0] + '.xlsx')
excel_path = os.path.join(args.table_output, os.path.basename(image_file).split('.')[0] + '.xlsx')
if not flag:
img = cv2.imread(image_file)
if img is None:
logger.info("error in loading image:{}".format(image_file))
logger.error("error in loading image:{}".format(image_file))
continue
starttime = time.time()
pred_html = text_sys(img)
......@@ -205,7 +204,7 @@ def main(args):
if __name__ == "__main__":
args = utility.parse_args()
args = parse_args()
if args.use_mp:
p_list = []
total_process_num = args.total_process_num
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from tools.infer.utility import str2bool, init_args as infer_args
def init_args():
parser = infer_args()
# params for output
parser.add_argument("--output", type=str, default='./output/table')
# params for table structure
parser.add_argument("--structure_max_len", type=int, default=488)
parser.add_argument("--structure_max_text_length", type=int, default=100)
parser.add_argument("--structure_max_elem_length", type=int, default=800)
parser.add_argument("--structure_max_cell_num", type=int, default=500)
parser.add_argument("--structure_model_dir", type=str)
parser.add_argument("--structure_char_type", type=str, default='en')
parser.add_argument("--structure_char_dict_path", type=str, default="../ppocr/utils/dict/table_structure_dict.txt")
# params for layout detector
parser.add_argument("--layout_model_dir", type=str)
return parser
def parse_args():
parser = init_args()
return parser.parse_args()
......@@ -88,7 +88,7 @@ class TextSystem(object):
def __call__(self, img, cls=True):
ori_im = img.copy()
dt_boxes, elapse = self.text_detector(img)
logger.info("dt_boxes num : {}, elapse : {}".format(
logger.debug("dt_boxes num : {}, elapse : {}".format(
len(dt_boxes), elapse))
if dt_boxes is None:
return None, None
......@@ -103,11 +103,11 @@ class TextSystem(object):
if self.use_angle_cls and cls:
img_crop_list, angle_list, elapse = self.text_classifier(
img_crop_list)
logger.info("cls num : {}, elapse : {}".format(
logger.debug("cls num : {}, elapse : {}".format(
len(img_crop_list), elapse))
rec_res, elapse = self.text_recognizer(img_crop_list)
logger.info("rec_res num : {}, elapse : {}".format(
logger.debug("rec_res num : {}, elapse : {}".format(
len(rec_res), elapse))
# self.print_draw_crop_rec_res(img_crop_list, rec_res)
filter_boxes, filter_rec_res = [], []
......@@ -152,7 +152,7 @@ def main(args):
if not flag:
img = cv2.imread(image_file)
if img is None:
logger.info("error in loading image:{}".format(image_file))
logger.error("error in loading image:{}".format(image_file))
continue
starttime = time.time()
dt_boxes, rec_res = text_sys(img)
......
......@@ -109,7 +109,7 @@ def init_args():
parser.add_argument("--use_mp", type=str2bool, default=False)
parser.add_argument("--total_process_num", type=int, default=1)
parser.add_argument("--process_id", type=int, default=0)
parser.add_argument("--show_log", type=str2bool, default=True)
return parser
......@@ -125,8 +125,8 @@ def create_predictor(args, mode, logger):
model_dir = args.cls_model_dir
elif mode == 'rec':
model_dir = args.rec_model_dir
elif mode == 'table':
model_dir = args.table_model_dir
elif mode == 'structure':
model_dir = args.structure_model_dir
else:
model_dir = args.e2e_model_dir
......@@ -246,7 +246,8 @@ def create_predictor(args, mode, logger):
config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
config.switch_use_feed_fetch_ops(False)
if mode == 'table':
config.switch_ir_optim(True)
if mode == 'structure':
config.switch_ir_optim(False)
# create predictor
predictor = inference.create_predictor(config)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册