提交 1e27820f 编写于 作者: 文幕地方's avatar 文幕地方

add merge flag

上级 38b3ca47
...@@ -73,12 +73,14 @@ class TableStructurer(object): ...@@ -73,12 +73,14 @@ class TableStructurer(object):
postprocess_params = { postprocess_params = {
'name': 'TableLabelDecode', 'name': 'TableLabelDecode',
"character_dict_path": args.table_char_dict_path, "character_dict_path": args.table_char_dict_path,
'merge_no_span_structure': args.merge_no_span_structure
} }
else: else:
postprocess_params = { postprocess_params = {
'name': 'TableMasterLabelDecode', 'name': 'TableMasterLabelDecode',
"character_dict_path": args.table_char_dict_path, "character_dict_path": args.table_char_dict_path,
'box_shape': 'pad' 'box_shape': 'pad',
'merge_no_span_structure': args.merge_no_span_structure
} }
self.preprocess_op = create_operators(pre_process_list) self.preprocess_op = create_operators(pre_process_list)
......
...@@ -101,6 +101,7 @@ class TableSystem(object): ...@@ -101,6 +101,7 @@ class TableSystem(object):
start = time.time() start = time.time()
structure_res, elapse = self._structure(copy.deepcopy(img)) structure_res, elapse = self._structure(copy.deepcopy(img))
result['cell_bbox'] = structure_res[1]
time_dict['table'] = elapse time_dict['table'] = elapse
dt_boxes, rec_res, det_elapse, rec_elapse = self._ocr( dt_boxes, rec_res, det_elapse, rec_elapse = self._ocr(
...@@ -175,8 +176,23 @@ def main(args): ...@@ -175,8 +176,23 @@ def main(args):
image_file_list = image_file_list[args.process_id::args.total_process_num] image_file_list = image_file_list[args.process_id::args.total_process_num]
os.makedirs(args.output, exist_ok=True) os.makedirs(args.output, exist_ok=True)
text_sys = TableSystem(args) table_sys = TableSystem(args)
img_num = len(image_file_list) img_num = len(image_file_list)
f_html = open(
os.path.join(args.output, 'show.html'), mode='w', encoding='utf-8')
f_html.write('<html>\n<body>\n')
f_html.write('<table border="1">\n')
f_html.write(
"<meta http-equiv=\"Content-Type\" content=\"text/html; charset=utf-8\" />"
)
f_html.write("<tr>\n")
f_html.write('<td>img name\n')
f_html.write('<td>ori image</td>')
f_html.write('<td>table html</td>')
f_html.write('<td>cell box</td>')
f_html.write("</tr>\n")
for i, image_file in enumerate(image_file_list): for i, image_file in enumerate(image_file_list):
logger.info("[{}/{}] {}".format(i, img_num, image_file)) logger.info("[{}/{}] {}".format(i, img_num, image_file))
img, flag = check_and_read_gif(image_file) img, flag = check_and_read_gif(image_file)
...@@ -188,13 +204,31 @@ def main(args): ...@@ -188,13 +204,31 @@ def main(args):
logger.error("error in loading image:{}".format(image_file)) logger.error("error in loading image:{}".format(image_file))
continue continue
starttime = time.time() starttime = time.time()
pred_res, _ = text_sys(img) pred_res, _ = table_sys(img)
pred_html = pred_res['html'] pred_html = pred_res['html']
logger.info(pred_html) logger.info(pred_html)
to_excel(pred_html, excel_path) to_excel(pred_html, excel_path)
logger.info('excel saved to {}'.format(excel_path)) logger.info('excel saved to {}'.format(excel_path))
elapse = time.time() - starttime elapse = time.time() - starttime
logger.info("Predict time : {:.3f}s".format(elapse)) logger.info("Predict time : {:.3f}s".format(elapse))
# img = predict_strture.draw_rectangle(image_file, pred_res['cell_bbox'], use_xywh)
img = utility.draw_boxes(cv2.imread(image_file), pred_res['cell_bbox'])
img_save_path = os.path.join(args.output, os.path.basename(image_file))
cv2.imwrite(img_save_path, img)
f_html.write("<tr>\n")
f_html.write(f'<td> {os.path.basename(image_file)} <br/>\n')
f_html.write(f'<td><img src="{image_file}" width=640></td>\n')
f_html.write('<td><table border="1">' + pred_html.replace(
'<html><body><table>', '').replace('</table></body></html>', '') +
'</table></td>\n')
f_html.write(
f'<td><img src="{os.path.basename(image_file)}" width=640></td>\n')
f_html.write("</tr>\n")
f_html.write("</table>\n")
f_html.close()
if args.benchmark: if args.benchmark:
text_sys.autolog.report() text_sys.autolog.report()
......
...@@ -27,6 +27,8 @@ def init_args(): ...@@ -27,6 +27,8 @@ def init_args():
parser.add_argument("--table_max_len", type=int, default=488) parser.add_argument("--table_max_len", type=int, default=488)
parser.add_argument("--table_algorithm", type=str, default='TableAttn') parser.add_argument("--table_algorithm", type=str, default='TableAttn')
parser.add_argument("--table_model_dir", type=str) parser.add_argument("--table_model_dir", type=str)
parser.add_argument(
"--merge_no_span_structure", type=str2bool, default=False)
parser.add_argument( parser.add_argument(
"--table_char_dict_path", "--table_char_dict_path",
type=str, type=str,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册