提交 97f7f748 编写于 作者: 文幕地方's avatar 文幕地方

add copyright

上级 c86c1740
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
convert table label to html
"""
import json
import argparse
from tqdm import tqdm
def save_pred_txt(key, val, tmp_file_path):
    """
    Append one tab-separated record to a text file.

    The record is written as ``key<TAB>val<NEWLINE>``; the file is created
    if it does not exist and existing content is kept.

    @param key: value written in the first column
    @param val: value written in the second column
    @param tmp_file_path: path of the file to append to
    @return: None
    """
    record = '{}\t{}\n'.format(key, val)
    with open(tmp_file_path, 'a+', encoding='utf-8') as out_file:
        out_file.write(record)
def skip_char(text, sp_char_list):
    """
    Remove every occurrence of the given substrings from a cell text.

    The substrings are removed one after another, in list order.

    @param text: text in cell
    @param sp_char_list: substrings to delete (style tags and special codes)
    @return: the text with all listed substrings removed
    """
    stripped = text
    for unwanted in sp_char_list:
        stripped = stripped.replace(unwanted, '')
    return stripped
def gen_html(img):
    """
    Build the HTML string of a table from its tokenized annotation.

    Reads the structure tokens from img['html']['structure']['tokens'] and
    the cell contents from img['html']['cells'] (each cell has a 'tokens'
    list). The text of a cell is placed right after the token that ends its
    opening tag ('<td>' or '>'). Cells with no tokens, or whose text is
    empty once style tags and blanks are removed, contribute nothing.

    @param img: annotation dict of one image
    @return: the table wrapped in <html><body><table>...</table></body></html>
    """
    structure_tokens = img['html']['structure']['tokens'].copy()
    # Positions of the tokens after which a cell's text belongs.
    slots = [
        pos for pos, tag in enumerate(structure_tokens) if tag in ('<td>', '>')
    ]
    ignorable = ['<b>', '</b>', '\u2028', ' ', '<i>', '</i>']

    # Pair slots and cells from the back so that, when the two counts
    # differ, the trailing entries are the ones that line up.
    cell_text_at = {}
    for pos, cell in zip(slots[::-1], img['html']['cells'][::-1]):
        if not cell['tokens']:
            continue
        text = ''.join(cell['tokens'])
        # skip cells that only hold style tags / whitespace
        if len(skip_char(text, ignorable)) == 0:
            continue
        cell_text_at[pos] = text

    # Emit every structure token, followed by its cell text when it has one.
    pieces = []
    for pos, tag in enumerate(structure_tokens):
        pieces.append(tag)
        if pos in cell_text_at:
            pieces.append(cell_text_at[pos])
    table_body = ''.join(pieces)
    return '<html><body><table>{}</table></body></html>'.format(table_body)
def load_gt_data(gt_path):
    """
    Load ground-truth annotations from a label file with one JSON object
    per line.

    Each non-blank line is decoded as UTF-8 and parsed with json.loads; the
    parsed object is stored under its 'filename' value. When a filename
    occurs more than once, the later line replaces the earlier one.

    @param gt_path: path of the label file
    @return: dict mapping filename to the parsed annotation
    """
    records = {}
    with open(gt_path, 'rb') as f:
        lines = f.readlines()
    for line in tqdm(lines):
        data_line = line.decode('utf-8').strip()
        if not data_line:
            # Blank lines (e.g. an empty last line) hold no record and
            # would make json.loads raise, so they are skipped.
            continue
        info = json.loads(data_line)
        records[info['filename']] = info
    return records
def convert(origin_gt_path, save_path):
    """
    Generate an HTML string for every entry of a label file and append the
    results to a text file.

    For each annotation returned by load_gt_data, one line of the form
    ``image_name<TAB>html`` is appended to save_path via save_pred_txt.

    @param origin_gt_path: path of the label file to read
    @param save_path: path of the text file the results are appended to
    @return: None
    """
    annotations = load_gt_data(origin_gt_path)
    for name, label in tqdm(annotations.items()):
        save_pred_txt(name, gen_html(label), save_path)
    print('conver finish')
def parse_args(argv=None):
    """
    Parse the command-line options of the label conversion script.

    @param argv: optional list of argument strings; when None, argparse
        takes the arguments from sys.argv
    @return: argparse.Namespace with the attributes ori_gt_path and
        save_path (both required strings)
    """
    parser = argparse.ArgumentParser(
        description="convert table label to html")
    parser.add_argument(
        "--ori_gt_path", type=str, required=True, help="label gt path")
    parser.add_argument(
        "--save_path", type=str, required=True, help="path to save file")
    return parser.parse_args(argv)
# Script entry point: read the two paths from the command line and run the
# conversion.
if __name__ == '__main__':
    cli_args = parse_args()
    convert(cli_args.ori_gt_path, cli_args.save_path)
import json # copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from ppstructure.table.table_master_match import deal_eb_token, deal_bb from ppstructure.table.table_master_match import deal_eb_token, deal_bb
...@@ -64,6 +78,11 @@ class TableMatch: ...@@ -64,6 +78,11 @@ class TableMatch:
for i, gt_box in enumerate(dt_boxes): for i, gt_box in enumerate(dt_boxes):
distances = [] distances = []
for j, pred_box in enumerate(pred_bboxes): for j, pred_box in enumerate(pred_bboxes):
if len(pred_box) == 8:
pred_box = [
np.min(pred_box[0::2]), np.min(pred_box[1::2]),
np.max(pred_box[0::2]), np.max(pred_box[1::2])
]
distances.append((distance(gt_box, pred_box), distances.append((distance(gt_box, pred_box),
1. - compute_iou(gt_box, pred_box) 1. - compute_iou(gt_box, pred_box)
)) # compute iou and l1 distance )) # compute iou and l1 distance
......
...@@ -133,6 +133,7 @@ class TableSystem(object): ...@@ -133,6 +133,7 @@ class TableSystem(object):
return structure_res, elapse return structure_res, elapse
def _ocr(self, img): def _ocr(self, img):
h, w = img.shape[:2]
if self.benchmark: if self.benchmark:
self.autolog.times.stamp() self.autolog.times.stamp()
dt_boxes, det_elapse = self.text_detector(copy.deepcopy(img)) dt_boxes, det_elapse = self.text_detector(copy.deepcopy(img))
...@@ -140,10 +141,10 @@ class TableSystem(object): ...@@ -140,10 +141,10 @@ class TableSystem(object):
r_boxes = [] r_boxes = []
for box in dt_boxes: for box in dt_boxes:
x_min = box[:, 0].min() - 1 x_min = max(0, box[:, 0].min() - 1)
x_max = box[:, 0].max() + 1 x_max = min(w, box[:, 0].max() + 1)
y_min = box[:, 1].min() - 1 y_min = max(0, box[:, 1].min() - 1)
y_max = box[:, 1].max() + 1 y_max = min(h, box[:, 1].max() + 1)
box = [x_min, y_min, x_max, y_max] box = [x_min, y_min, x_max, y_max]
r_boxes.append(box) r_boxes.append(box)
dt_boxes = np.array(r_boxes) dt_boxes = np.array(r_boxes)
......
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is adapted from:
https://github.com/JiaquanYe/TableMASTER-mmocr/blob/master/table_recognition/match.py
"""
import os import os
import re import re
import cv2 import cv2
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册