提交 dfba983c 编写于 作者: W WenmuZhou

remove unused code

上级 3b81e304
...@@ -382,23 +382,12 @@ class TableLabelDecode(object): ...@@ -382,23 +382,12 @@ class TableLabelDecode(object):
"""convert text-label into text-index. """convert text-label into text-index.
""" """
if char_or_elem == "char": if char_or_elem == "char":
max_len = self.max_text_length
current_dict = self.dict_idx_character current_dict = self.dict_idx_character
else: else:
max_len = self.max_elem_length
current_dict = self.dict_idx_elem current_dict = self.dict_idx_elem
ignored_tokens = self.get_ignored_tokens('elem') ignored_tokens = self.get_ignored_tokens('elem')
beg_idx, end_idx = ignored_tokens beg_idx, end_idx = ignored_tokens
# select_td_tokens = []
# select_span_tokens = []
# for elem in self.dict_elem:
# # if elem == '<td>' or elem == '<td' or elem == '<tr>'\
# # or 'rowspan' in elem or 'colspan' in elem:
# if elem == '<td>' or elem == '<td' or elem == '<tr>':
# select_td_tokens.append(self.dict_elem[elem])
# if 'rowspan' in elem or 'colspan' in elem:
# select_span_tokens.append(self.dict_elem[elem])
result_list = [] result_list = []
result_pos_list = [] result_pos_list = []
result_score_list = [] result_score_list = []
...@@ -415,12 +404,7 @@ class TableLabelDecode(object): ...@@ -415,12 +404,7 @@ class TableLabelDecode(object):
break break
if tmp_elem_idx in ignored_tokens: if tmp_elem_idx in ignored_tokens:
continue continue
# if tmp_elem_idx in select_td_tokens:
# total_td_score += structure_probs[batch_idx, idx]
# total_td_num += 1
# if tmp_elem_idx in select_span_tokens:
# total_span_score += structure_probs[batch_idx, idx]
# total_span_num += 1
char_list.append(current_dict[tmp_elem_idx]) char_list.append(current_dict[tmp_elem_idx])
elem_pos_list.append(idx) elem_pos_list.append(idx)
score_list.append(structure_probs[batch_idx, idx]) score_list.append(structure_probs[batch_idx, idx])
......
...@@ -38,15 +38,15 @@ def main(gt_path, img_root, args): ...@@ -38,15 +38,15 @@ def main(gt_path, img_root, args):
pred_htmls = [] pred_htmls = []
gt_htmls = [] gt_htmls = []
for img_name in tqdm(jsons_gt): for img_name in tqdm(jsons_gt):
# 读取信息 # read image
img = cv2.imread(os.path.join(img_root,img_name)) img = cv2.imread(os.path.join(img_root,img_name))
pred_html = text_sys(img) pred_html = text_sys(img)
pred_htmls.append(pred_html) pred_htmls.append(pred_html)
gt_structures, gt_bboxes, gt_contents, contents_with_block = jsons_gt[img_name] gt_structures, gt_bboxes, gt_contents, contents_with_block = jsons_gt[img_name]
gt_html, gt = get_gt_html(gt_structures, contents_with_block) # 获取HTMLgt gt_html, gt = get_gt_html(gt_structures, contents_with_block)
gt_htmls.append(gt_html) gt_htmls.append(gt_html)
scores = teds.batch_evaluate_html(gt_htmls, pred_htmls) # 计算teds scores = teds.batch_evaluate_html(gt_htmls, pred_htmls)
print('teds:', sum(scores) / len(scores)) print('teds:', sum(scores) / len(scores))
......
...@@ -2,14 +2,9 @@ import json ...@@ -2,14 +2,9 @@ import json
def distance(box_1, box_2): def distance(box_1, box_2):
x1, y1, x2, y2 = box_1 x1, y1, x2, y2 = box_1
x3, y3, x4, y4 = box_2 x3, y3, x4, y4 = box_2
# min_x = (x1 + x2) / 2
# min_y = (y1 + y2) / 2
# max_x = (x3 + x4) / 2
# max_y = (y3 + y4) / 2
dis = abs(x3 - x1) + abs(y3 - y1) + abs(x4- x2) + abs(y4 - y2) dis = abs(x3 - x1) + abs(y3 - y1) + abs(x4- x2) + abs(y4 - y2)
dis_2 = abs(x3 - x1) + abs(y3 - y1) dis_2 = abs(x3 - x1) + abs(y3 - y1)
dis_3 = abs(x4- x2) + abs(y4 - y2) dis_3 = abs(x4- x2) + abs(y4 - y2)
#dis = pow(min_x - max_x, 2) + pow(min_y - max_y, 2) + pow(x3 - x1, 2) + pow(y3 - y1, 2) + pow(x4- x2, 2) + pow(y4 - y2, 2) + abs(x3 - x1) + abs(y3 - y1) + abs(x4- x2) + abs(y4 - y2)
return dis + min(dis_2, dis_3) return dis + min(dis_2, dis_3)
def compute_iou(rec1, rec2): def compute_iou(rec1, rec2):
...@@ -21,7 +16,6 @@ def compute_iou(rec1, rec2): ...@@ -21,7 +16,6 @@ def compute_iou(rec1, rec2):
:return: scala value of IoU :return: scala value of IoU
""" """
# computing area of each rectangles # computing area of each rectangles
rec1, rec2 = rec1 * 1000, rec2 * 1000
S_rec1 = (rec1[2] - rec1[0]) * (rec1[3] - rec1[1]) S_rec1 = (rec1[2] - rec1[0]) * (rec1[3] - rec1[1])
S_rec2 = (rec2[2] - rec2[0]) * (rec2[3] - rec2[1]) S_rec2 = (rec2[2] - rec2[0]) * (rec2[3] - rec2[1])
...@@ -36,29 +30,31 @@ def compute_iou(rec1, rec2): ...@@ -36,29 +30,31 @@ def compute_iou(rec1, rec2):
# judge if there is an intersect # judge if there is an intersect
if left_line >= right_line or top_line >= bottom_line: if left_line >= right_line or top_line >= bottom_line:
return 0 return 0.0
else: else:
intersect = (right_line - left_line) * (bottom_line - top_line) intersect = (right_line - left_line) * (bottom_line - top_line)
return (intersect / (sum_area - intersect))*1.0 return (intersect / (sum_area - intersect))*1.0
def matcher_merge(ocr_bboxes, pred_bboxes): # ocr_bboxes: OCR pred_bboxes:端到端 def matcher_merge(ocr_bboxes, pred_bboxes):
all_dis = [] all_dis = []
ious = [] ious = []
matched = {} matched = {}
for i, gt_box in enumerate(ocr_bboxes): for i, gt_box in enumerate(ocr_bboxes):
distances = [] distances = []
for j, pred_box in enumerate(pred_bboxes): for j, pred_box in enumerate(pred_bboxes):
distances.append((distance(gt_box, pred_box), 1. - compute_iou(gt_box, pred_box))) #获取两两cell之间的L1距离和 1- IOU # compute l1 distence and IOU between two boxes
distances.append((distance(gt_box, pred_box), 1. - compute_iou(gt_box, pred_box)))
sorted_distances = distances.copy() sorted_distances = distances.copy()
# 根据距离和IOU挑选最"近"的cell # select nearest cell
sorted_distances = sorted(sorted_distances, key = lambda item: (item[1], item[0])) sorted_distances = sorted(sorted_distances, key = lambda item: (item[1], item[0]))
if distances.index(sorted_distances[0]) not in matched.keys(): if distances.index(sorted_distances[0]) not in matched.keys():
matched[distances.index(sorted_distances[0])] = [i] matched[distances.index(sorted_distances[0])] = [i]
else: else:
matched[distances.index(sorted_distances[0])].append(i) matched[distances.index(sorted_distances[0])].append(i)
return matched#, sum(ious) / len(ious) return matched#, sum(ious) / len(ious)
def complex_num(pred_bboxes): def complex_num(pred_bboxes):
complex_nums = [] complex_nums = []
for bbox in pred_bboxes: for bbox in pred_bboxes:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册