From e2b84da86655eec1b03f320c5a75174864355eba Mon Sep 17 00:00:00 2001 From: Jethong <1147925384@qq.com> Date: Tue, 20 Apr 2021 19:47:19 +0800 Subject: [PATCH] add eval mode --- configs/e2e/e2e_r50_vd_pg.yml | 6 +- ppocr/data/imaug/label_ops.py | 28 +++++- ppocr/metrics/e2e_metric.py | 51 +++++++++-- ppocr/utils/e2e_metric/Deteval.py | 139 +++++++++++++++++++++++++++++- 4 files changed, 212 insertions(+), 12 deletions(-) diff --git a/configs/e2e/e2e_r50_vd_pg.yml b/configs/e2e/e2e_r50_vd_pg.yml index 0bacf12d..205a46a5 100644 --- a/configs/e2e/e2e_r50_vd_pg.yml +++ b/configs/e2e/e2e_r50_vd_pg.yml @@ -62,6 +62,7 @@ PostProcess: mode: fast # fast or slow two ways Metric: name: E2EMetric + mode: A # A or B gt_mat_dir: ./train_data/total_text/gt # the dir of gt_mat character_dict_path: ppocr/utils/ic15_dict.txt main_indicator: f_score_e2e @@ -76,7 +77,7 @@ Train: - DecodeImage: # load image img_mode: BGR channel_first: False - - E2ELabelEncode: + - E2ELabelEncode_train: - PGProcessTrain: batch_size: 14 # same as loader: batch_size_per_card min_crop_size: 24 @@ -99,6 +100,7 @@ Eval: - DecodeImage: # load image img_mode: RGB channel_first: False + - E2ELabelEncode_test: - E2EResizeForTest: max_side_len: 768 - NormalizeImage: @@ -108,7 +110,7 @@ Eval: order: 'hwc' - ToCHWImage: - KeepKeys: - keep_keys: [ 'image', 'shape', 'img_id'] + keep_keys: [ 'image', 'shape', 'polys', 'texts', 'tags', 'img_id'] loader: shuffle: False drop_last: False diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py index 61cc7303..aae37ad9 100644 --- a/ppocr/data/imaug/label_ops.py +++ b/ppocr/data/imaug/label_ops.py @@ -187,7 +187,33 @@ class CTCLabelEncode(BaseRecLabelEncode): return dict_character -class E2ELabelEncode(object): +class E2ELabelEncode_test(BaseRecLabelEncode): + def __init__(self, + max_text_length, + character_dict_path=None, + character_type='EN', + use_space_char=False, + **kwargs): + super(E2ELabelEncode_test, + self).__init__(max_text_length, character_dict_path, + character_type, use_space_char) + + def __call__(self, data): + texts = data['texts'] + temp_texts = [] + for text in texts: + text = text.lower() + text = self.encode(text) + if text is None: + return None + text = text + [36] * (self.max_text_len - len(text) + ) # use 36 to pad + temp_texts.append(text) + data['texts'] = np.array(temp_texts) + return data + + +class E2ELabelEncode_train(object): def __init__(self, **kwargs): pass diff --git a/ppocr/metrics/e2e_metric.py b/ppocr/metrics/e2e_metric.py index 525aa003..aeef43f9 100644 --- a/ppocr/metrics/e2e_metric.py +++ b/ppocr/metrics/e2e_metric.py @@ -18,16 +18,18 @@ from __future__ import print_function __all__ = ['E2EMetric'] -from ppocr.utils.e2e_metric.Deteval import get_socre, combine_results +from ppocr.utils.e2e_metric.Deteval import get_socre_A, get_socre_B, combine_results from ppocr.utils.e2e_utils.extract_textpoint_slow import get_dict class E2EMetric(object): def __init__(self, + mode, gt_mat_dir, character_dict_path, main_indicator='f_score_e2e', **kwargs): + self.mode = mode self.gt_mat_dir = gt_mat_dir self.label_list = get_dict(character_dict_path) self.max_index = len(self.label_list) @@ -35,13 +37,46 @@ class E2EMetric(object): self.reset() def __call__(self, preds, batch, **kwargs): - img_id = batch[2][0] - e2e_info_list = [{ - 'points': det_polyon, - 'texts': pred_str - } for det_polyon, pred_str in zip(preds['points'], preds['texts'])] - result = get_socre(self.gt_mat_dir, img_id, e2e_info_list) - self.results.append(result) + if self.mode == 'A': + gt_polyons_batch = batch[2] + temp_gt_strs_batch = batch[3] + ignore_tags_batch = batch[4] + gt_strs_batch = [] + + for temp_list in temp_gt_strs_batch: + t = "" + for index in temp_list: + if index < self.max_index: + t += self.label_list[index] + gt_strs_batch.append(t) + + for pred, gt_polyons, gt_strs, ignore_tags in zip( + [preds], [gt_polyons_batch], [gt_strs_batch], + ignore_tags_batch): + # prepare gt + gt_info_list = [{ + 'points': gt_polyon, + 'text': gt_str, + 'ignore': ignore_tag + } for gt_polyon, gt_str, ignore_tag in + zip(gt_polyons, gt_strs, ignore_tags)] + # prepare det + e2e_info_list = [{ + 'points': det_polyon, + 'texts': pred_str + } for det_polyon, pred_str in + zip(pred['points'], pred['texts'])] + + result = get_socre_A(gt_info_list, e2e_info_list) + self.results.append(result) + else: + img_id = batch[5][0] + e2e_info_list = [{ + 'points': det_polyon, + 'texts': pred_str + } for det_polyon, pred_str in zip(preds['points'], preds['texts'])] + result = get_socre_B(self.gt_mat_dir, img_id, e2e_info_list) + self.results.append(result) def get_metric(self): metircs = combine_results(self.results) diff --git a/ppocr/utils/e2e_metric/Deteval.py b/ppocr/utils/e2e_metric/Deteval.py index 2aa09304..45567a7d 100755 --- a/ppocr/utils/e2e_metric/Deteval.py +++ b/ppocr/utils/e2e_metric/Deteval.py @@ -17,7 +17,144 @@ import scipy.io as io from ppocr.utils.e2e_metric.polygon_fast import iod, area_of_intersection, area -def get_socre(gt_dir, img_id, pred_dict): +def get_socre_A(gt_dir, pred_dict): + allInputs = 1 + + def input_reading_mod(pred_dict): + """This helper reads input from txt files""" + det = [] + n = len(pred_dict) + for i in range(n): + points = pred_dict[i]['points'] + text = pred_dict[i]['texts'] + point = ",".join(map(str, points.reshape(-1, ))) + det.append([point, text]) + return det + + def gt_reading_mod(gt_dict): + """This helper reads groundtruths from mat files""" + gt = [] + n = len(gt_dict) + for i in range(n): + points = gt_dict[i]['points'].tolist() + h = len(points) + text = gt_dict[i]['text'] + xx = [ + np.array( + ['x:'], dtype=' 1): + gt_x = list(map(int, np.squeeze(gt[1]))) + gt_y = list(map(int, np.squeeze(gt[3]))) + for det_id, detection in enumerate(detections): + detection_orig = detection + detection = [float(x) for x in detection[0].split(',')] + detection = list(map(int, detection)) + det_x = detection[0::2] + det_y = detection[1::2] + det_gt_iou = iod(det_x, det_y, gt_x, gt_y) + if det_gt_iou > threshold: + detections[det_id] = [] + + detections[:] = [item for item in detections if item != []] + return detections + + def sigma_calculation(det_x, det_y, gt_x, gt_y): + """ + sigma = inter_area / gt_area + """ + return np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) / + area(gt_x, gt_y)), 2) + + def tau_calculation(det_x, det_y, gt_x, gt_y): + if area(det_x, det_y) == 0.0: + return 0 + return np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) / + area(det_x, det_y)), 2) + + ##############################Initialization################################### + # global_sigma = [] + # global_tau = [] + # global_pred_str = [] + # global_gt_str = [] + ############################################################################### + + for input_id in range(allInputs): + if (input_id != '.DS_Store') and (input_id != 'Pascal_result.txt') and ( + input_id != 'Pascal_result_curved.txt') and (input_id != 'Pascal_result_non_curved.txt') and ( + input_id != 'Deteval_result.txt') and (input_id != 'Deteval_result_curved.txt') \ + and (input_id != 'Deteval_result_non_curved.txt'): + detections = input_reading_mod(pred_dict) + groundtruths = gt_reading_mod(gt_dir) + detections = detection_filtering( + detections, + groundtruths) # filters detections overlapping with DC area + dc_id = [] + for i in range(len(groundtruths)): + if groundtruths[i][5] == '#': + dc_id.append(i) + cnt = 0 + for a in dc_id: + num = a - cnt + del groundtruths[num] + cnt += 1 + + local_sigma_table = np.zeros((len(groundtruths), len(detections))) + local_tau_table = np.zeros((len(groundtruths), len(detections))) + local_pred_str = {} + local_gt_str = {} + + for gt_id, gt in enumerate(groundtruths): + if len(detections) > 0: + for det_id, detection in enumerate(detections): + detection_orig = detection + detection = [float(x) for x in detection[0].split(',')] + detection = list(map(int, detection)) + pred_seq_str = detection_orig[1].strip() + det_x = detection[0::2] + det_y = detection[1::2] + gt_x = list(map(int, np.squeeze(gt[1]))) + gt_y = list(map(int, np.squeeze(gt[3]))) + gt_seq_str = str(gt[4].tolist()[0]) + + local_sigma_table[gt_id, det_id] = sigma_calculation( + det_x, det_y, gt_x, gt_y) + local_tau_table[gt_id, det_id] = tau_calculation( + det_x, det_y, gt_x, gt_y) + local_pred_str[det_id] = pred_seq_str + local_gt_str[gt_id] = gt_seq_str + + global_sigma = local_sigma_table + global_tau = local_tau_table + global_pred_str = local_pred_str + global_gt_str = local_gt_str + + single_data = {} + single_data['sigma'] = global_sigma + single_data['global_tau'] = global_tau + single_data['global_pred_str'] = global_pred_str + single_data['global_gt_str'] = global_gt_str + return single_data + + +def get_socre_B(gt_dir, img_id, pred_dict): allInputs = 1 def input_reading_mod(pred_dict): -- GitLab