From bc9fc148faddf1458736cf43cfbd1caeebb63959 Mon Sep 17 00:00:00 2001 From: LDOUBLEV Date: Fri, 1 Apr 2022 17:06:09 +0800 Subject: [PATCH] delete enc2end --- tools/end2end/convert_ppocr_label.py | 94 ------------- tools/end2end/draw_html.py | 73 ---------- tools/end2end/eval_end2end.py | 193 --------------------------- tools/end2end/readme.md | 69 ---------- 4 files changed, 429 deletions(-) delete mode 100644 tools/end2end/convert_ppocr_label.py delete mode 100644 tools/end2end/draw_html.py delete mode 100644 tools/end2end/eval_end2end.py delete mode 100644 tools/end2end/readme.md diff --git a/tools/end2end/convert_ppocr_label.py b/tools/end2end/convert_ppocr_label.py deleted file mode 100644 index 8084cac7..00000000 --- a/tools/end2end/convert_ppocr_label.py +++ /dev/null @@ -1,94 +0,0 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import numpy as np -import json -import os - - -def poly_to_string(poly): - if len(poly.shape) > 1: - poly = np.array(poly).flatten() - - string = "\t".join(str(i) for i in poly) - return string - - -def convert_label(label_dir, mode="gt", save_dir="./save_results/"): - if not os.path.exists(label_dir): - raise ValueError(f"The file {label_dir} does not exist!") - - assert label_dir != save_dir, "hahahhaha" - - label_file = open(label_dir, 'r') - data = label_file.readlines() - - gt_dict = {} - - for line in data: - try: - tmp = line.split('\t') - assert len(tmp) == 2, "" - except: - tmp = line.strip().split(' ') - - gt_lists = [] - - if tmp[0].split('/')[0] is not None: - img_path = tmp[0] - anno = json.loads(tmp[1]) - gt_collect = [] - for dic in anno: - #txt = dic['transcription'].replace(' ', '') # ignore blank - txt = dic['transcription'] - if 'score' in dic and float(dic['score']) < 0.5: - continue - if u'\u3000' in txt: txt = txt.replace(u'\u3000', u' ') - #while ' ' in txt: - # txt = txt.replace(' ', '') - poly = np.array(dic['points']).flatten() - if txt == "###": - txt_tag = 1 ## ignore 1 - else: - txt_tag = 0 - if mode == "gt": - gt_label = poly_to_string(poly) + "\t" + str( - txt_tag) + "\t" + txt + "\n" - else: - gt_label = poly_to_string(poly) + "\t" + txt + "\n" - - gt_lists.append(gt_label) - - gt_dict[img_path] = gt_lists - else: - continue - - if not os.path.exists(save_dir): - os.makedirs(save_dir) - - for img_name in gt_dict.keys(): - save_name = img_name.split("/")[-1] - save_file = os.path.join(save_dir, save_name + ".txt") - with open(save_file, "w") as f: - f.writelines(gt_dict[img_name]) - - print("The convert label saved in {}".format(save_dir)) - - -if __name__ == "__main__": - - ppocr_label_gt = "/paddle/Datasets/chinese/test_set/Label_refine_310_V2.txt" - convert_label(ppocr_label_gt, "gt", "./save_gt_310_V2/") - - ppocr_label_gt = "./infer_results/ch_PPOCRV2_infer.txt" - convert_label(ppocr_label_gt_en, "pred", "./save_PPOCRV2_infer/") diff --git a/tools/end2end/draw_html.py b/tools/end2end/draw_html.py deleted file mode 100644 index fcac8ad3..00000000 --- a/tools/end2end/draw_html.py +++ /dev/null @@ -1,73 +0,0 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import argparse - - -def str2bool(v): - return v.lower() in ("true", "t", "1") - - -def init_args(): - parser = argparse.ArgumentParser() - parser.add_argument("--image_dir", type=str, default="") - parser.add_argument("--save_html_path", type=str, default="./default.html") - parser.add_argument("--width", type=int, default=640) - return parser - - -def parse_args(): - parser = init_args() - return parser.parse_args() - - -def draw_debug_img(args): - - html_path = args.save_html_path - - err_cnt = 0 - with open(html_path, 'w') as html: - html.write('\n\n') - html.write('\n') - html.write( - "" - ) - image_list = [] - path = args.image_dir - for i, filename in enumerate(sorted(os.listdir(path))): - if filename.endswith("txt"): continue - # The image path - base = "{}/{}".format(path, filename) - html.write("\n") - html.write(f'') - - html.write("\n") - html.write('\n') - html.write('
{filename}\n GT') - html.write(f'GT\n
\n') - html.write('\n\n') - print(f"The html file saved in {html_path}") - return - - -if __name__ == "__main__": - - args = parse_args() - - draw_debug_img(args) diff --git a/tools/end2end/eval_end2end.py b/tools/end2end/eval_end2end.py deleted file mode 100644 index 6e7573ca..00000000 --- a/tools/end2end/eval_end2end.py +++ /dev/null @@ -1,193 +0,0 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import re -import sys -import shapely -from shapely.geometry import Polygon -import numpy as np -from collections import defaultdict -import operator -import editdistance - - -def strQ2B(ustring): - rstring = "" - for uchar in ustring: - inside_code = ord(uchar) - if inside_code == 12288: - inside_code = 32 - elif (inside_code >= 65281 and inside_code <= 65374): - inside_code -= 65248 - rstring += chr(inside_code) - return rstring - - -def polygon_from_str(polygon_points): - """ - Create a shapely polygon object from gt or dt line. - """ - polygon_points = np.array(polygon_points).reshape(4, 2) - polygon = Polygon(polygon_points).convex_hull - return polygon - - -def polygon_iou(poly1, poly2): - """ - Intersection over union between two shapely polygons. - """ - if not poly1.intersects( - poly2): # this test is fast and can accelerate calculation - iou = 0 - else: - try: - inter_area = poly1.intersection(poly2).area - union_area = poly1.area + poly2.area - inter_area - iou = float(inter_area) / union_area - except shapely.geos.TopologicalError: - # except Exception as e: - # print(e) - print('shapely.geos.TopologicalError occured, iou set to 0') - iou = 0 - return iou - - -def ed(str1, str2): - return editdistance.eval(str1, str2) - - -def e2e_eval(gt_dir, res_dir, ignore_blank=False): - print('start testing...') - iou_thresh = 0.5 - val_names = os.listdir(gt_dir) - num_gt_chars = 0 - gt_count = 0 - dt_count = 0 - hit = 0 - ed_sum = 0 - - for i, val_name in enumerate(val_names): - with open(os.path.join(gt_dir, val_name), encoding='utf-8') as f: - gt_lines = [o.strip() for o in f.readlines()] - gts = [] - ignore_masks = [] - for line in gt_lines: - parts = line.strip().split('\t') - # ignore illegal data - if len(parts) < 9: - continue - assert (len(parts) < 11) - if len(parts) == 9: - gts.append(parts[:8] + ['']) - else: - gts.append(parts[:8] + [parts[-1]]) - - ignore_masks.append(parts[8]) - - val_path = os.path.join(res_dir, val_name) - if not os.path.exists(val_path): - dt_lines = [] - else: - with open(val_path, encoding='utf-8') as f: - dt_lines = [o.strip() for o in f.readlines()] - dts = [] - for line in dt_lines: - # print(line) - parts = line.strip().split("\t") - assert (len(parts) < 10), "line error: {}".format(line) - if len(parts) == 8: - dts.append(parts + ['']) - else: - dts.append(parts) - - dt_match = [False] * len(dts) - gt_match = [False] * len(gts) - all_ious = defaultdict(tuple) - for index_gt, gt in enumerate(gts): - gt_coors = [float(gt_coor) for gt_coor in gt[0:8]] - gt_poly = polygon_from_str(gt_coors) - for index_dt, dt in enumerate(dts): - dt_coors = [float(dt_coor) for dt_coor in dt[0:8]] - dt_poly = polygon_from_str(dt_coors) - iou = polygon_iou(dt_poly, gt_poly) - if iou >= iou_thresh: - all_ious[(index_gt, index_dt)] = iou - sorted_ious = sorted( - all_ious.items(), key=operator.itemgetter(1), reverse=True) - sorted_gt_dt_pairs = [item[0] for item in sorted_ious] - - # matched gt and dt - for gt_dt_pair in sorted_gt_dt_pairs: - index_gt, index_dt = gt_dt_pair - if gt_match[index_gt] == False and dt_match[index_dt] == False: - gt_match[index_gt] = True - dt_match[index_dt] = True - if ignore_blank: - gt_str = strQ2B(gts[index_gt][8]).replace(" ", "") - dt_str = strQ2B(dts[index_dt][8]).replace(" ", "") - else: - gt_str = strQ2B(gts[index_gt][8]) - dt_str = strQ2B(dts[index_dt][8]) - if ignore_masks[index_gt] == '0': - ed_sum += ed(gt_str, dt_str) - num_gt_chars += len(gt_str) - if gt_str == dt_str: - hit += 1 - gt_count += 1 - dt_count += 1 - - # unmatched dt - for tindex, dt_match_flag in enumerate(dt_match): - if dt_match_flag == False: - dt_str = dts[tindex][8] - gt_str = '' - ed_sum += ed(dt_str, gt_str) - dt_count += 1 - - # unmatched gt - for tindex, gt_match_flag in enumerate(gt_match): - if gt_match_flag == False and ignore_masks[tindex] == '0': - dt_str = '' - gt_str = gts[tindex][8] - ed_sum += ed(gt_str, dt_str) - num_gt_chars += len(gt_str) - gt_count += 1 - - eps = 1e-9 - print('hit, dt_count, gt_count', hit, dt_count, gt_count) - precision = hit / (dt_count + eps) - recall = hit / (gt_count + eps) - fmeasure = 2.0 * precision * recall / (precision + recall + eps) - avg_edit_dist_img = ed_sum / len(val_names) - avg_edit_dist_field = ed_sum / (gt_count + eps) - character_acc = 1 - ed_sum / (num_gt_chars + eps) - - print('character_acc: %.2f' % (character_acc * 100) + "%") - print('avg_edit_dist_field: %.2f' % (avg_edit_dist_field)) - print('avg_edit_dist_img: %.2f' % (avg_edit_dist_img)) - print('precision: %.2f' % (precision * 100) + "%") - print('recall: %.2f' % (recall * 100) + "%") - print('fmeasure: %.2f' % (fmeasure * 100) + "%") - - -if __name__ == '__main__': - # if len(sys.argv) != 3: - # print("python3 ocr_e2e_eval.py gt_dir res_dir") - # exit(-1) - # gt_folder = sys.argv[1] - # pred_folder = sys.argv[2] - gt_folder = sys.argv[1] - pred_folder = sys.argv[2] - e2e_eval(gt_folder, pred_folder) diff --git a/tools/end2end/readme.md b/tools/end2end/readme.md deleted file mode 100644 index 257a6f38..00000000 --- a/tools/end2end/readme.md +++ /dev/null @@ -1,69 +0,0 @@ - -# 简介 - -`tools/end2end`目录下存放了文本检测+文本识别pipeline串联预测的指标评测代码以及可视化工具。本节介绍文本检测+文本识别的端对端指标评估方式。 - - -## 端对端评测步骤 - -**步骤一:** - -运行`tools/infer/predict_system.py`,得到保存的结果: - -``` -python3 tools/infer/predict_system.py --det_model_dir=./ch_PP-OCRv2_det_infer/ --rec_model_dir=./ch_PP-OCRv2_rec_infer/ --image_dir=./datasets/img_dir/ --draw_img_save_dir=./ch_PP-OCRv2_results/ --is_visualize=True -``` - -文本检测识别可视化图默认保存在`./ch_PP-OCRv2_results/`目录下,预测结果默认保存在`./ch_PP-OCRv2_results/results.txt`中,格式如下: -``` -all-sum-510/00224225.jpg [{"transcription": "超赞", "points": [[8.0, 48.0], [157.0, 44.0], [159.0, 115.0], [10.0, 119.0]], "score": "0.99396634"}, {"transcription": "中", "points": [[202.0, 152.0], [230.0, 152.0], [230.0, 163.0], [202.0, 163.0]], "score": "0.09310734"}, {"transcription": "58.0m", "points": [[196.0, 192.0], [444.0, 192.0], [444.0, 240.0], [196.0, 240.0]], "score": "0.44041982"}, {"transcription": "汽配", "points": [[55.0, 263.0], [95.0, 263.0], [95.0, 281.0], [55.0, 281.0]], "score": "0.9986651"}, {"transcription": "成总店", "points": [[120.0, 262.0], [176.0, 262.0], [176.0, 283.0], [120.0, 283.0]], "score": "0.9929402"}, {"transcription": "K", "points": [[237.0, 286.0], [311.0, 286.0], [311.0, 345.0], [237.0, 345.0]], "score": "0.6074794"}, {"transcription": "88:-8", "points": [[203.0, 405.0], [477.0, 414.0], [475.0, 459.0], [201.0, 450.0]], "score": "0.7106863"}] -``` - - -**步骤二:** - -将步骤一保存的数据转换为端对端评测需要的数据格式: -修改 `tools/convert_ppocr_label.py`中的代码,convert_label函数中设置输入标签路径,Mode,保存标签路径等,对预测数据的GTlabel和预测结果的label格式进行转换。 - -``` -ppocr_label_gt = "gt_label.txt" -convert_label(ppocr_label_gt, "gt", "./save_gt_label/") - -ppocr_label_gt = "./infer_results/ch_PPOCRV2_infer.txt" -convert_label(ppocr_label_gt_en, "pred", "./save_PPOCRV2_infer/") -``` - -运行`convert_ppocr_label.py`: -``` -python3 tools/convert_ppocr_label.py -``` - -得到如下结果: -``` -├── ./save_gt_label/ -├── ./save_PPOCRV2_infer/ -``` - -**步骤三:** - -执行端对端评测,运行`tools/eval_end2end.py`计算端对端指标,运行方式如下: - -``` -python3 tools/eval_end2end.py "gt_label_dir" "predict_label_dir" -``` - -比如: - -``` -python3 tools/eval_end2end.py ./save_gt_label/ ./save_PPOCRV2_infer/ -``` -将得到如下结果,fmeasure为主要关注的指标: -``` -hit, dt_count, gt_count 1557 2693 3283 -character_acc: 61.77% -avg_edit_dist_field: 3.08 -avg_edit_dist_img: 51.82 -precision: 57.82% -recall: 47.43% -fmeasure: 52.11% -``` -- GitLab