提交 4cf04cbe 编写于 作者: 文幕地方's avatar 文幕地方

fix table recogition benckmark error

上级 4078b0fe
...@@ -68,6 +68,7 @@ def build_pre_process_list(args): ...@@ -68,6 +68,7 @@ def build_pre_process_list(args):
class TableStructurer(object): class TableStructurer(object):
def __init__(self, args): def __init__(self, args):
self.args = args
self.use_onnx = args.use_onnx self.use_onnx = args.use_onnx
pre_process_list = build_pre_process_list(args) pre_process_list = build_pre_process_list(args)
if args.table_algorithm not in ['TableMaster']: if args.table_algorithm not in ['TableMaster']:
...@@ -89,8 +90,31 @@ class TableStructurer(object): ...@@ -89,8 +90,31 @@ class TableStructurer(object):
self.predictor, self.input_tensor, self.output_tensors, self.config = \ self.predictor, self.input_tensor, self.output_tensors, self.config = \
utility.create_predictor(args, 'table', logger) utility.create_predictor(args, 'table', logger)
if args.benchmark:
import auto_log
pid = os.getpid()
gpu_id = utility.get_infer_gpuid()
self.autolog = auto_log.AutoLogger(
model_name="table",
model_precision=args.precision,
batch_size=1,
data_shape="dynamic",
save_path=None, #args.save_log_path,
inference_config=self.config,
pids=pid,
process_name=None,
gpu_ids=gpu_id if args.use_gpu else None,
time_keys=[
'preprocess_time', 'inference_time', 'postprocess_time'
],
warmup=0,
logger=logger)
def __call__(self, img): def __call__(self, img):
starttime = time.time() starttime = time.time()
if self.args.benchmark:
self.autolog.times.start()
ori_im = img.copy() ori_im = img.copy()
data = {'image': img} data = {'image': img}
data = transform(data, self.preprocess_op) data = transform(data, self.preprocess_op)
...@@ -99,6 +123,8 @@ class TableStructurer(object): ...@@ -99,6 +123,8 @@ class TableStructurer(object):
return None, 0 return None, 0
img = np.expand_dims(img, axis=0) img = np.expand_dims(img, axis=0)
img = img.copy() img = img.copy()
if self.args.benchmark:
self.autolog.times.stamp()
if self.use_onnx: if self.use_onnx:
input_dict = {} input_dict = {}
input_dict[self.input_tensor.name] = img input_dict[self.input_tensor.name] = img
...@@ -110,6 +136,8 @@ class TableStructurer(object): ...@@ -110,6 +136,8 @@ class TableStructurer(object):
for output_tensor in self.output_tensors: for output_tensor in self.output_tensors:
output = output_tensor.copy_to_cpu() output = output_tensor.copy_to_cpu()
outputs.append(output) outputs.append(output)
if self.args.benchmark:
self.autolog.times.stamp()
preds = {} preds = {}
preds['structure_probs'] = outputs[1] preds['structure_probs'] = outputs[1]
...@@ -125,6 +153,8 @@ class TableStructurer(object): ...@@ -125,6 +153,8 @@ class TableStructurer(object):
'<html>', '<body>', '<table>' '<html>', '<body>', '<table>'
] + structure_str_list + ['</table>', '</body>', '</html>'] ] + structure_str_list + ['</table>', '</body>', '</html>']
elapse = time.time() - starttime elapse = time.time() - starttime
if self.args.benchmark:
self.autolog.times.end(stamp=True)
return (structure_str_list, bbox_list), elapse return (structure_str_list, bbox_list), elapse
...@@ -164,6 +194,8 @@ def main(args): ...@@ -164,6 +194,8 @@ def main(args):
total_time += elapse total_time += elapse
count += 1 count += 1
logger.info("Predict time of {}: {}".format(image_file, elapse)) logger.info("Predict time of {}: {}".format(image_file, elapse))
if args.benchmark:
table_structurer.autolog.report()
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
import os import os
import sys import sys
import subprocess
__dir__ = os.path.dirname(os.path.abspath(__file__)) __dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__) sys.path.append(__dir__)
...@@ -61,57 +60,31 @@ class TableSystem(object): ...@@ -61,57 +60,31 @@ class TableSystem(object):
self.args = args self.args = args
if not args.show_log: if not args.show_log:
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
args.benchmark = False
self.text_detector = predict_det.TextDetector( self.text_detector = predict_det.TextDetector(copy.deepcopy(
args) if text_detector is None else text_detector args)) if text_detector is None else text_detector
self.text_recognizer = predict_rec.TextRecognizer( self.text_recognizer = predict_rec.TextRecognizer(copy.deepcopy(
args) if text_recognizer is None else text_recognizer args)) if text_recognizer is None else text_recognizer
args.benchmark = True
self.table_structurer = predict_strture.TableStructurer(args) self.table_structurer = predict_strture.TableStructurer(args)
if args.table_algorithm in ['TableMaster']: if args.table_algorithm in ['TableMaster']:
self.match = TableMasterMatcher() self.match = TableMasterMatcher()
else: else:
self.match = TableMatch(filter_ocr_result=True) self.match = TableMatch(filter_ocr_result=True)
self.benchmark = args.benchmark
self.predictor, self.input_tensor, self.output_tensors, self.config = utility.create_predictor( self.predictor, self.input_tensor, self.output_tensors, self.config = utility.create_predictor(
args, 'table', logger) args, 'table', logger)
if args.benchmark:
import auto_log
pid = os.getpid()
gpu_id = utility.get_infer_gpuid()
self.autolog = auto_log.AutoLogger(
model_name="table",
model_precision=args.precision,
batch_size=1,
data_shape="dynamic",
save_path=None, #args.save_log_path,
inference_config=self.config,
pids=pid,
process_name=None,
gpu_ids=gpu_id if args.use_gpu else None,
time_keys=[
'preprocess_time', 'inference_time', 'postprocess_time'
],
warmup=0,
logger=logger)
def __call__(self, img, return_ocr_result_in_table=False): def __call__(self, img, return_ocr_result_in_table=False):
result = dict() result = dict()
time_dict = {'det': 0, 'rec': 0, 'table': 0, 'all': 0, 'match': 0} time_dict = {'det': 0, 'rec': 0, 'table': 0, 'all': 0, 'match': 0}
start = time.time() start = time.time()
if self.args.benchmark:
self.autolog.times.start()
structure_res, elapse = self._structure(copy.deepcopy(img)) structure_res, elapse = self._structure(copy.deepcopy(img))
if self.benchmark:
self.autolog.times.stamp()
result['cell_bbox'] = structure_res[1].tolist() result['cell_bbox'] = structure_res[1].tolist()
time_dict['table'] = elapse time_dict['table'] = elapse
dt_boxes, rec_res, det_elapse, rec_elapse = self._ocr( dt_boxes, rec_res, det_elapse, rec_elapse = self._ocr(
copy.deepcopy(img)) copy.deepcopy(img))
if self.benchmark:
self.autolog.times.stamp()
time_dict['det'] = det_elapse time_dict['det'] = det_elapse
time_dict['rec'] = rec_elapse time_dict['rec'] = rec_elapse
...@@ -126,8 +99,6 @@ class TableSystem(object): ...@@ -126,8 +99,6 @@ class TableSystem(object):
result['html'] = pred_html result['html'] = pred_html
end = time.time() end = time.time()
time_dict['all'] = end - start time_dict['all'] = end - start
if self.benchmark:
self.autolog.times.end(stamp=True)
return result, time_dict return result, time_dict
def _structure(self, img): def _structure(self, img):
...@@ -233,12 +204,13 @@ def main(args): ...@@ -233,12 +204,13 @@ def main(args):
f_html.close() f_html.close()
if args.benchmark: if args.benchmark:
table_sys.autolog.report() table_sys.table_structurer.autolog.report()
if __name__ == "__main__": if __name__ == "__main__":
args = parse_args() args = parse_args()
if args.use_mp: if args.use_mp:
import subprocess
p_list = [] p_list = []
total_process_num = args.total_process_num total_process_num = args.total_process_num
for process_id in range(total_process_num): for process_id in range(total_process_num):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册