diff --git a/configs/table/SLANet.yml b/configs/table/SLANet.yml index 105fb5fad287ba014a16c1138f5f9a3a25ad609f..acf8d0304558816d206a3c7f37de4aaba301683b 100644 --- a/configs/table/SLANet.yml +++ b/configs/table/SLANet.yml @@ -1,6 +1,6 @@ Global: use_gpu: true - epoch_num: 400 + epoch_num: 100 log_smooth_window: 20 print_batch_step: 20 save_model_dir: ./output/SLANet @@ -28,7 +28,10 @@ Optimizer: beta2: 0.999 clip_norm: 5.0 lr: + name: Piecewise learning_rate: 0.001 + decay_epochs : [40, 50] + values : [0.001, 0.0001, 0.00005] regularizer: name: 'L2' factor: 0.00000 @@ -105,8 +108,8 @@ Train: Eval: dataset: name: PubTabDataSet - data_dir: /home/zhoujun20/table/PubTabNe/pubtabnet/val/ - label_file_list: [/home/zhoujun20/table/PubTabNe/pubtabnet/val_500.jsonl] + data_dir: train_data/table/pubtabnet/val/ + label_file_list: [train_data/table/pubtabnet/PubTabNet_2.0.0_val.jsonl] transforms: - DecodeImage: # load image img_mode: BGR diff --git a/ppstructure/table/matcher.py b/ppstructure/table/matcher.py index d75e9abb341be4eae2a7b7cfc4995fa95543f222..a25f869c9a63817c2b2ad9c26e9c54c2c4c24f5b 100755 --- a/ppstructure/table/matcher.py +++ b/ppstructure/table/matcher.py @@ -40,169 +40,6 @@ def compute_iou(rec1, rec2): return (intersect / (sum_area - intersect)) * 1.0 -def matcher_merge(ocr_bboxes, pred_bboxes): - all_dis = [] - ious = [] - matched = {} - for i, gt_box in enumerate(ocr_bboxes): - distances = [] - for j, pred_box in enumerate(pred_bboxes): - # compute l1 distence and IOU between two boxes - distances.append((distance(gt_box, pred_box), - 1. - compute_iou(gt_box, pred_box))) - sorted_distances = distances.copy() - # select nearest cell - sorted_distances = sorted( - sorted_distances, key=lambda item: (item[1], item[0])) - if distances.index(sorted_distances[0]) not in matched.keys(): - matched[distances.index(sorted_distances[0])] = [i] - else: - matched[distances.index(sorted_distances[0])].append(i) - return matched #, sum(ious) / len(ious) - - -def complex_num(pred_bboxes): - complex_nums = [] - for bbox in pred_bboxes: - distances = [] - temp_ious = [] - for pred_bbox in pred_bboxes: - if bbox != pred_bbox: - distances.append(distance(bbox, pred_bbox)) - temp_ious.append(compute_iou(bbox, pred_bbox)) - complex_nums.append(temp_ious[distances.index(min(distances))]) - return sum(complex_nums) / len(complex_nums) - - -def get_rows(pred_bboxes): - pre_bbox = pred_bboxes[0] - res = [] - step = 0 - for i in range(len(pred_bboxes)): - bbox = pred_bboxes[i] - if bbox[1] - pre_bbox[1] > 2 or bbox[0] - pre_bbox[0] < 0: - break - else: - res.append(bbox) - step += 1 - for i in range(step): - pred_bboxes.pop(0) - return res, pred_bboxes - - -def refine_rows(pred_bboxes): # 微调整行的框,使在一条水平线上 - ys_1 = [] - ys_2 = [] - for box in pred_bboxes: - ys_1.append(box[1]) - ys_2.append(box[3]) - min_y_1 = sum(ys_1) / len(ys_1) - min_y_2 = sum(ys_2) / len(ys_2) - re_boxes = [] - for box in pred_bboxes: - box[1] = min_y_1 - box[3] = min_y_2 - re_boxes.append(box) - return re_boxes - - -def matcher_refine_row(gt_bboxes, pred_bboxes): - before_refine_pred_bboxes = pred_bboxes.copy() - pred_bboxes = [] - while (len(before_refine_pred_bboxes) != 0): - row_bboxes, before_refine_pred_bboxes = get_rows( - before_refine_pred_bboxes) - print(row_bboxes) - pred_bboxes.extend(refine_rows(row_bboxes)) - all_dis = [] - ious = [] - matched = {} - for i, gt_box in enumerate(gt_bboxes): - distances = [] - #temp_ious = [] - for j, pred_box in enumerate(pred_bboxes): - distances.append(distance(gt_box, pred_box)) - #temp_ious.append(compute_iou(gt_box, pred_box)) - #all_dis.append(min(distances)) - #ious.append(temp_ious[distances.index(min(distances))]) - if distances.index(min(distances)) not in matched.keys(): - matched[distances.index(min(distances))] = [i] - else: - matched[distances.index(min(distances))].append(i) - return matched #, sum(ious) / len(ious) - - -#先挑选出一行,再进行匹配 -def matcher_structure_1(gt_bboxes, pred_bboxes_rows, pred_bboxes): - gt_box_index = 0 - delete_gt_bboxes = gt_bboxes.copy() - match_bboxes_ready = [] - matched = {} - while (len(delete_gt_bboxes) != 0): - row_bboxes, delete_gt_bboxes = get_rows(delete_gt_bboxes) - row_bboxes = sorted(row_bboxes, key=lambda key: key[0]) - if len(pred_bboxes_rows) > 0: - match_bboxes_ready.extend(pred_bboxes_rows.pop(0)) - print(row_bboxes) - for i, gt_box in enumerate(row_bboxes): - #print(gt_box) - pred_distances = [] - distances = [] - for pred_bbox in pred_bboxes: - pred_distances.append(distance(gt_box, pred_bbox)) - for j, pred_box in enumerate(match_bboxes_ready): - distances.append(distance(gt_box, pred_box)) - index = pred_distances.index(min(distances)) - #print('index', index) - if index not in matched.keys(): - matched[index] = [gt_box_index] - else: - matched[index].append(gt_box_index) - gt_box_index += 1 - return matched - - -def matcher_structure(gt_bboxes, pred_bboxes_rows, pred_bboxes): - ''' - gt_bboxes: 排序后 - pred_bboxes: - ''' - pre_bbox = gt_bboxes[0] - matched = {} - match_bboxes_ready = [] - match_bboxes_ready.extend(pred_bboxes_rows.pop(0)) - for i, gt_box in enumerate(gt_bboxes): - - pred_distances = [] - for pred_bbox in pred_bboxes: - pred_distances.append(distance(gt_box, pred_bbox)) - distances = [] - gap_pre = gt_box[1] - pre_bbox[1] - gap_pre_1 = gt_box[0] - pre_bbox[2] - #print(gap_pre, len(pred_bboxes_rows)) - if (gap_pre_1 < 0 and len(pred_bboxes_rows) > 0): - match_bboxes_ready.extend(pred_bboxes_rows.pop(0)) - if len(pred_bboxes_rows) == 1: - match_bboxes_ready.extend(pred_bboxes_rows.pop(0)) - if len(match_bboxes_ready) == 0 and len(pred_bboxes_rows) > 0: - match_bboxes_ready.extend(pred_bboxes_rows.pop(0)) - if len(match_bboxes_ready) == 0 and len(pred_bboxes_rows) == 0: - break - #print(match_bboxes_ready) - for j, pred_box in enumerate(match_bboxes_ready): - distances.append(distance(gt_box, pred_box)) - index = pred_distances.index(min(distances)) - #print(gt_box, index) - #match_bboxes_ready.pop(distances.index(min(distances))) - print(gt_box, match_bboxes_ready[distances.index(min(distances))]) - if index not in matched.keys(): - matched[index] = [i] - else: - matched[index].append(i) - pre_bbox = gt_box - return matched - - class TableMatch: def __init__(self, filter_ocr_result=False, use_master=False): self.filter_ocr_result = filter_ocr_result @@ -225,14 +62,13 @@ class TableMatch: def match_result(self, dt_boxes, pred_bboxes): matched = {} for i, gt_box in enumerate(dt_boxes): - # gt_box = [np.min(gt_box[:, 0]), np.min(gt_box[:, 1]), np.max(gt_box[:, 0]), np.max(gt_box[:, 1])] distances = [] for j, pred_box in enumerate(pred_bboxes): distances.append((distance(gt_box, pred_box), 1. - compute_iou(gt_box, pred_box) - )) # 获取两两cell之间的L1距离和 1- IOU + )) # compute iou and l1 distance sorted_distances = distances.copy() - # 根据距离和IOU挑选最"近"的cell + # select det box by iou and l1 distance sorted_distances = sorted( sorted_distances, key=lambda item: (item[1], item[0])) if distances.index(sorted_distances[0]) not in matched.keys():