提交 296809a0 编写于 作者: A andyjpaddle

debug_table

上级 8a5a9870
...@@ -28,6 +28,7 @@ import numpy as np ...@@ -28,6 +28,7 @@ import numpy as np
import time import time
import tools.infer.predict_rec as predict_rec import tools.infer.predict_rec as predict_rec
import tools.infer.predict_det as predict_det import tools.infer.predict_det as predict_det
import tools.infer.utility as utility
from ppocr.utils.utility import get_image_file_list, check_and_read_gif from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppocr.utils.logging import get_logger from ppocr.utils.logging import get_logger
from ppstructure.table.matcher import distance, compute_iou from ppstructure.table.matcher import distance, compute_iou
...@@ -59,11 +60,37 @@ class TableSystem(object): ...@@ -59,11 +60,37 @@ class TableSystem(object):
self.text_recognizer = predict_rec.TextRecognizer( self.text_recognizer = predict_rec.TextRecognizer(
args) if text_recognizer is None else text_recognizer args) if text_recognizer is None else text_recognizer
self.table_structurer = predict_strture.TableStructurer(args) self.table_structurer = predict_strture.TableStructurer(args)
self.benchmark = args.benchmark
self.predictor, self.input_tensor, self.output_tensors, self.config = utility.create_predictor(
args, 'table', logger)
if args.benchmark:
import auto_log
pid = os.getpid()
gpu_id = utility.get_infer_gpuid()
self.autolog = auto_log.AutoLogger(
model_name="table",
model_precision=args.precision,
batch_size=1,
data_shape="dynamic",
save_path=None, #args.save_log_path,
inference_config=self.config,
pids=pid,
process_name=None,
gpu_ids=gpu_id if args.use_gpu else None,
time_keys=[
'preprocess_time', 'inference_time', 'postprocess_time'
],
warmup=0,
logger=logger)
def __call__(self, img, return_ocr_result_in_table=False): def __call__(self, img, return_ocr_result_in_table=False):
result = dict() result = dict()
ori_im = img.copy() ori_im = img.copy()
if self.benchmark:
self.autolog.times.start()
structure_res, elapse = self.table_structurer(copy.deepcopy(img)) structure_res, elapse = self.table_structurer(copy.deepcopy(img))
if self.benchmark:
self.autolog.times.stamp()
dt_boxes, elapse = self.text_detector(copy.deepcopy(img)) dt_boxes, elapse = self.text_detector(copy.deepcopy(img))
dt_boxes = sorted_boxes(dt_boxes) dt_boxes = sorted_boxes(dt_boxes)
if return_ocr_result_in_table: if return_ocr_result_in_table:
...@@ -77,13 +104,11 @@ class TableSystem(object): ...@@ -77,13 +104,11 @@ class TableSystem(object):
box = [x_min, y_min, x_max, y_max] box = [x_min, y_min, x_max, y_max]
r_boxes.append(box) r_boxes.append(box)
dt_boxes = np.array(r_boxes) dt_boxes = np.array(r_boxes)
logger.debug("dt_boxes num : {}, elapse : {}".format( logger.debug("dt_boxes num : {}, elapse : {}".format(
len(dt_boxes), elapse)) len(dt_boxes), elapse))
if dt_boxes is None: if dt_boxes is None:
return None, None return None, None
img_crop_list = [] img_crop_list = []
for i in range(len(dt_boxes)): for i in range(len(dt_boxes)):
det_box = dt_boxes[i] det_box = dt_boxes[i]
x0, y0, x1, y1 = expand(2, det_box, ori_im.shape) x0, y0, x1, y1 = expand(2, det_box, ori_im.shape)
...@@ -92,10 +117,14 @@ class TableSystem(object): ...@@ -92,10 +117,14 @@ class TableSystem(object):
rec_res, elapse = self.text_recognizer(img_crop_list) rec_res, elapse = self.text_recognizer(img_crop_list)
logger.debug("rec_res num : {}, elapse : {}".format( logger.debug("rec_res num : {}, elapse : {}".format(
len(rec_res), elapse)) len(rec_res), elapse))
if self.benchmark:
self.autolog.times.stamp()
if return_ocr_result_in_table: if return_ocr_result_in_table:
result['rec_res'] = rec_res result['rec_res'] = rec_res
pred_html, pred = self.rebuild_table(structure_res, dt_boxes, rec_res) pred_html, pred = self.rebuild_table(structure_res, dt_boxes, rec_res)
result['html'] = pred_html result['html'] = pred_html
if self.benchmark:
self.autolog.times.end(stamp=True)
return result return result
def rebuild_table(self, structure_res, dt_boxes, rec_res): def rebuild_table(self, structure_res, dt_boxes, rec_res):
...@@ -213,11 +242,15 @@ def main(args): ...@@ -213,11 +242,15 @@ def main(args):
logger.info('excel saved to {}'.format(excel_path)) logger.info('excel saved to {}'.format(excel_path))
elapse = time.time() - starttime elapse = time.time() - starttime
logger.info("Predict time : {:.3f}s".format(elapse)) logger.info("Predict time : {:.3f}s".format(elapse))
if args.benchmark:
text_sys.autolog.report()
print('ok')
if __name__ == "__main__": if __name__ == "__main__":
args = parse_args() args = parse_args()
if args.use_mp: if args.use_mp:
print('mp')
p_list = [] p_list = []
total_process_num = args.total_process_num total_process_num = args.total_process_num
for process_id in range(total_process_num): for process_id in range(total_process_num):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册