diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py
index 594197a6cd862664b17ed8d84c2d7cd908332386..85ce580f95b13539c6aeea32b188bfd3b435d140 100644
--- a/ppocr/postprocess/rec_postprocess.py
+++ b/ppocr/postprocess/rec_postprocess.py
@@ -382,23 +382,12 @@ class TableLabelDecode(object):
"""convert text-label into text-index.
"""
if char_or_elem == "char":
- max_len = self.max_text_length
current_dict = self.dict_idx_character
else:
- max_len = self.max_elem_length
current_dict = self.dict_idx_elem
ignored_tokens = self.get_ignored_tokens('elem')
beg_idx, end_idx = ignored_tokens
- # select_td_tokens = []
- # select_span_tokens = []
- # for elem in self.dict_elem:
- # # if elem == '
' or elem == ' | '\
- # # or 'rowspan' in elem or 'colspan' in elem:
- # if elem == ' | ' or elem == ' | ':
- # select_td_tokens.append(self.dict_elem[elem])
- # if 'rowspan' in elem or 'colspan' in elem:
- # select_span_tokens.append(self.dict_elem[elem])
result_list = []
result_pos_list = []
result_score_list = []
@@ -415,12 +404,7 @@ class TableLabelDecode(object):
break
if tmp_elem_idx in ignored_tokens:
continue
- # if tmp_elem_idx in select_td_tokens:
- # total_td_score += structure_probs[batch_idx, idx]
- # total_td_num += 1
- # if tmp_elem_idx in select_span_tokens:
- # total_span_score += structure_probs[batch_idx, idx]
- # total_span_num += 1
+
char_list.append(current_dict[tmp_elem_idx])
elem_pos_list.append(idx)
score_list.append(structure_probs[batch_idx, idx])
diff --git a/ppstructure/table/eval_table.py b/ppstructure/table/eval_table.py
index 00b9cd51e4b068c727288923c670050c923a23c5..1bcbaa8d0d0b2669828dc6b19c3370a30c522ede 100755
--- a/ppstructure/table/eval_table.py
+++ b/ppstructure/table/eval_table.py
@@ -38,15 +38,15 @@ def main(gt_path, img_root, args):
pred_htmls = []
gt_htmls = []
for img_name in tqdm(jsons_gt):
- # 读取信息
+ # read image
img = cv2.imread(os.path.join(img_root,img_name))
pred_html = text_sys(img)
pred_htmls.append(pred_html)
gt_structures, gt_bboxes, gt_contents, contents_with_block = jsons_gt[img_name]
- gt_html, gt = get_gt_html(gt_structures, contents_with_block) # 获取HTMLgt
+ gt_html, gt = get_gt_html(gt_structures, contents_with_block)
gt_htmls.append(gt_html)
- scores = teds.batch_evaluate_html(gt_htmls, pred_htmls) # 计算teds
+ scores = teds.batch_evaluate_html(gt_htmls, pred_htmls)
print('teds:', sum(scores) / len(scores))
diff --git a/ppstructure/table/matcher.py b/ppstructure/table/matcher.py
index b3c7043029b6a5df9f64d22efb73971b072dc0ca..c3b56384403f5fd92a8db4b4bb378a6d55e5a76c 100755
--- a/ppstructure/table/matcher.py
+++ b/ppstructure/table/matcher.py
@@ -2,14 +2,9 @@ import json
def distance(box_1, box_2):
x1, y1, x2, y2 = box_1
x3, y3, x4, y4 = box_2
- # min_x = (x1 + x2) / 2
- # min_y = (y1 + y2) / 2
- # max_x = (x3 + x4) / 2
- # max_y = (y3 + y4) / 2
dis = abs(x3 - x1) + abs(y3 - y1) + abs(x4- x2) + abs(y4 - y2)
dis_2 = abs(x3 - x1) + abs(y3 - y1)
dis_3 = abs(x4- x2) + abs(y4 - y2)
- #dis = pow(min_x - max_x, 2) + pow(min_y - max_y, 2) + pow(x3 - x1, 2) + pow(y3 - y1, 2) + pow(x4- x2, 2) + pow(y4 - y2, 2) + abs(x3 - x1) + abs(y3 - y1) + abs(x4- x2) + abs(y4 - y2)
return dis + min(dis_2, dis_3)
def compute_iou(rec1, rec2):
@@ -21,7 +16,6 @@ def compute_iou(rec1, rec2):
:return: scala value of IoU
"""
# computing area of each rectangles
- rec1, rec2 = rec1 * 1000, rec2 * 1000
S_rec1 = (rec1[2] - rec1[0]) * (rec1[3] - rec1[1])
S_rec2 = (rec2[2] - rec2[0]) * (rec2[3] - rec2[1])
@@ -36,29 +30,31 @@ def compute_iou(rec1, rec2):
# judge if there is an intersect
if left_line >= right_line or top_line >= bottom_line:
- return 0
+ return 0.0
else:
intersect = (right_line - left_line) * (bottom_line - top_line)
return (intersect / (sum_area - intersect))*1.0
-def matcher_merge(ocr_bboxes, pred_bboxes): # ocr_bboxes: OCR pred_bboxes:端到端
+def matcher_merge(ocr_bboxes, pred_bboxes):
all_dis = []
ious = []
matched = {}
for i, gt_box in enumerate(ocr_bboxes):
distances = []
for j, pred_box in enumerate(pred_bboxes):
- distances.append((distance(gt_box, pred_box), 1. - compute_iou(gt_box, pred_box))) #获取两两cell之间的L1距离和 1- IOU
+ # compute l1 distence and IOU between two boxes
+ distances.append((distance(gt_box, pred_box), 1. - compute_iou(gt_box, pred_box)))
sorted_distances = distances.copy()
- # 根据距离和IOU挑选最"近"的cell
+ # select nearest cell
sorted_distances = sorted(sorted_distances, key = lambda item: (item[1], item[0]))
if distances.index(sorted_distances[0]) not in matched.keys():
matched[distances.index(sorted_distances[0])] = [i]
else:
matched[distances.index(sorted_distances[0])].append(i)
return matched#, sum(ious) / len(ious)
+
def complex_num(pred_bboxes):
complex_nums = []
for bbox in pred_bboxes:
|