From dbaab10e882d358536148edc3525f080b55e2d38 Mon Sep 17 00:00:00 2001 From: LDOUBLEV Date: Fri, 25 Mar 2022 10:11:08 +0800 Subject: [PATCH] add end2end --- tools/end2end/convert_ppocr_label.py | 81 ++++++++++++ tools/end2end/draw_html.py | 52 ++++++++ tools/end2end/eval_end2end.py | 182 +++++++++++++++++++++++++++ tools/end2end/readme.md | 69 ++++++++++ tools/infer/predict_system.py | 42 ++++++- tools/infer/utility.py | 1 + 6 files changed, 426 insertions(+), 1 deletion(-) create mode 100644 tools/end2end/convert_ppocr_label.py create mode 100644 tools/end2end/draw_html.py create mode 100644 tools/end2end/eval_end2end.py create mode 100644 tools/end2end/readme.md diff --git a/tools/end2end/convert_ppocr_label.py b/tools/end2end/convert_ppocr_label.py new file mode 100644 index 00000000..a230a14a --- /dev/null +++ b/tools/end2end/convert_ppocr_label.py @@ -0,0 +1,81 @@ +import numpy as np +import json +import os + + +def poly_to_string(poly): + if len(poly.shape) > 1: + poly = np.array(poly).flatten() + + string = "\t".join(str(i) for i in poly) + return string + + +def convert_label(label_dir, mode="gt", save_dir="./save_results/"): + if not os.path.exists(label_dir): + raise ValueError(f"The file {label_dir} does not exist!") + + assert label_dir != save_dir, "hahahhaha" + + label_file = open(label_dir, 'r') + data = label_file.readlines() + + gt_dict = {} + + for line in data: + try: + tmp = line.split('\t') + assert len(tmp) == 2, "" + except: + tmp = line.strip().split(' ') + + gt_lists = [] + + if tmp[0].split('/')[0] is not None: + img_path = tmp[0] + anno = json.loads(tmp[1]) + gt_collect = [] + for dic in anno: + #txt = dic['transcription'].replace(' ', '') # ignore blank + txt = dic['transcription'] + if 'score' in dic and float(dic['score']) < 0.5: + continue + if u'\u3000' in txt: txt = txt.replace(u'\u3000', u' ') + #while ' ' in txt: + # txt = txt.replace(' ', '') + poly = np.array(dic['points']).flatten() + if txt == "###": + txt_tag = 1 ## ignore 1 + else: + txt_tag = 0 + if mode == "gt": + gt_label = poly_to_string(poly) + "\t" + str( + txt_tag) + "\t" + txt + "\n" + else: + gt_label = poly_to_string(poly) + "\t" + txt + "\n" + + gt_lists.append(gt_label) + + gt_dict[img_path] = gt_lists + else: + continue + + if not os.path.exists(save_dir): + os.makedirs(save_dir) + + for img_name in gt_dict.keys(): + save_name = img_name.split("/")[-1] + save_file = os.path.join(save_dir, save_name + ".txt") + with open(save_file, "w") as f: + f.writelines(gt_dict[img_name]) + + print("The convert label saved in {}".format(save_dir)) + + +if __name__ == "__main__": + + ppocr_label_gt = "/paddle/Datasets/chinese/test_set/Label_refine_310_V2.txt" + convert_label(ppocr_label_gt, "gt", "./save_gt_310_V2/") + + ppocr_label_gt = "./infer_results/ch_PPOCRV2_infer.txt" + convert_label(ppocr_label_gt_en, "pred", "./save_PPOCRV2_infer/") diff --git a/tools/end2end/draw_html.py b/tools/end2end/draw_html.py new file mode 100644 index 00000000..5f9a4719 --- /dev/null +++ b/tools/end2end/draw_html.py @@ -0,0 +1,52 @@ +import os + + +def draw_debug_img(html_path): + + err_cnt = 0 + with open(html_path, 'w') as html: + html.write('\n\n') + html.write('\n') + html.write( + "" + ) + image_list = [] + path = "./det_results/310_gt/" + #path = "infer_results/" + for i, filename in enumerate(sorted(os.listdir(path))): + if filename.endswith("txt"): continue + print(filename) + # The image path + base = "{}/{}".format(path, filename) + base_2 = "../PaddleOCR/det_results/ch_PPOCRV2_infer/{}".format( + filename) + base_3 = "../PaddleOCR/det_results/ch_ppocr_mobile_infer/{}".format( + filename) + + if True: + html.write("\n") + html.write(f'' % (base)) + html.write('' % + (base_2)) + html.write('' % + (base_3)) + + html.write("\n") + html.write('\n') + html.write('
{filename}\n GT') + html.write('GT\nPPOCRV2\nppocr_mobile\n
\n') + html.write('\n\n') + print("ok") + #print("all cnt: {}, err cnt: {}, acc: {}".format(len(imgs), err_cnt, 1.0 * (len(imgs) - err_cnt) / len(imgs))) + return + + +if __name__ == "__main__": + + html_path = "sys_visual_iou_310.html" + + draw_debug_img() diff --git a/tools/end2end/eval_end2end.py b/tools/end2end/eval_end2end.py new file mode 100644 index 00000000..88277786 --- /dev/null +++ b/tools/end2end/eval_end2end.py @@ -0,0 +1,182 @@ +#!/usr/bin/env python +import os +import re +import sys +# import Polygon +import shapely +from shapely.geometry import Polygon +import numpy as np +from collections import defaultdict +import operator +import editdistance + +# reload(sys) +# sys.setdefaultencoding('utf-8') + + +def strQ2B(ustring): + rstring = "" + for uchar in ustring: + inside_code = ord(uchar) + if inside_code == 12288: + inside_code = 32 + elif (inside_code >= 65281 and inside_code <= 65374): + inside_code -= 65248 + rstring += chr(inside_code) + return rstring + + +def polygon_from_str(polygon_points): + """ + Create a shapely polygon object from gt or dt line. + """ + polygon_points = np.array(polygon_points).reshape(4, 2) + polygon = Polygon(polygon_points).convex_hull + return polygon + + +def polygon_iou(poly1, poly2): + """ + Intersection over union between two shapely polygons. + """ + if not poly1.intersects( + poly2): # this test is fast and can accelerate calculation + iou = 0 + else: + try: + inter_area = poly1.intersection(poly2).area + union_area = poly1.area + poly2.area - inter_area + iou = float(inter_area) / union_area + except shapely.geos.TopologicalError: + # except Exception as e: + # print(e) + print('shapely.geos.TopologicalError occured, iou set to 0') + iou = 0 + return iou + + +def ed(str1, str2): + return editdistance.eval(str1, str2) + + +def e2e_eval(gt_dir, res_dir): + print('start testing...') + iou_thresh = 0.5 + val_names = os.listdir(gt_dir) + num_gt_chars = 0 + gt_count = 0 + dt_count = 0 + hit = 0 + ed_sum = 0 + + for i, val_name in enumerate(val_names): + with open(os.path.join(gt_dir, val_name), encoding='utf-8') as f: + gt_lines = [o.strip() for o in f.readlines()] + gts = [] + ignore_masks = [] + for line in gt_lines: + parts = line.strip().split('\t') + # ignore illegal data + if len(parts) < 9: + continue + assert (len(parts) < 11) + if len(parts) == 9: + gts.append(parts[:8] + ['']) + else: + gts.append(parts[:8] + [parts[-1]]) + + ignore_masks.append(parts[8]) + + val_path = os.path.join(res_dir, val_name) + if not os.path.exists(val_path): + dt_lines = [] + else: + with open(val_path, encoding='utf-8') as f: + dt_lines = [o.strip() for o in f.readlines()] + dts = [] + for line in dt_lines: + # print(line) + parts = line.strip().split("\t") + assert (len(parts) < 10), "line error: {}".format(line) + if len(parts) == 8: + dts.append(parts + ['']) + else: + dts.append(parts) + + dt_match = [False] * len(dts) + gt_match = [False] * len(gts) + all_ious = defaultdict(tuple) + for index_gt, gt in enumerate(gts): + gt_coors = [float(gt_coor) for gt_coor in gt[0:8]] + gt_poly = polygon_from_str(gt_coors) + for index_dt, dt in enumerate(dts): + dt_coors = [float(dt_coor) for dt_coor in dt[0:8]] + dt_poly = polygon_from_str(dt_coors) + iou = polygon_iou(dt_poly, gt_poly) + if iou >= iou_thresh: + all_ious[(index_gt, index_dt)] = iou + sorted_ious = sorted( + all_ious.items(), key=operator.itemgetter(1), reverse=True) + sorted_gt_dt_pairs = [item[0] for item in sorted_ious] + + # matched gt and dt + for gt_dt_pair in sorted_gt_dt_pairs: + index_gt, index_dt = gt_dt_pair + if gt_match[index_gt] == False and dt_match[index_dt] == False: + gt_match[index_gt] = True + dt_match[index_dt] = True + # gt_str = strQ2B(gts[index_gt][8]).replace(" ", "") + # dt_str = strQ2B(dts[index_dt][8]).replace(" ", "") + gt_str = strQ2B(gts[index_gt][8]) + dt_str = strQ2B(dts[index_dt][8]) + if ignore_masks[index_gt] == '0': + ed_sum += ed(gt_str, dt_str) + num_gt_chars += len(gt_str) + if gt_str == dt_str: + hit += 1 + gt_count += 1 + dt_count += 1 + + # unmatched dt + for tindex, dt_match_flag in enumerate(dt_match): + if dt_match_flag == False: + dt_str = dts[tindex][8] + gt_str = '' + ed_sum += ed(dt_str, gt_str) + dt_count += 1 + + # unmatched gt + for tindex, gt_match_flag in enumerate(gt_match): + if gt_match_flag == False and ignore_masks[tindex] == '0': + dt_str = '' + gt_str = gts[tindex][8] + ed_sum += ed(gt_str, dt_str) + num_gt_chars += len(gt_str) + gt_count += 1 + + eps = 1e-9 + print('hit, dt_count, gt_count', hit, dt_count, gt_count) + precision = hit / (dt_count + eps) + recall = hit / (gt_count + eps) + fmeasure = 2.0 * precision * recall / (precision + recall + eps) + avg_edit_dist_img = ed_sum / len(val_names) + avg_edit_dist_field = ed_sum / (gt_count + eps) + character_acc = 1 - ed_sum / (num_gt_chars + eps) + + print('character_acc: %.2f' % (character_acc * 100) + "%") + print('avg_edit_dist_field: %.2f' % (avg_edit_dist_field)) + print('avg_edit_dist_img: %.2f' % (avg_edit_dist_img)) + print('precision: %.2f' % (precision * 100) + "%") + print('recall: %.2f' % (recall * 100) + "%") + print('fmeasure: %.2f' % (fmeasure * 100) + "%") + + +if __name__ == '__main__': + # if len(sys.argv) != 3: + # print("python3 ocr_e2e_eval.py gt_dir res_dir") + # exit(-1) + # gt_folder = sys.argv[1] + # pred_folder = sys.argv[2] + gt_folder = sys.argv[1] + pred_folder = sys.argv[2] + e2e_eval(gt_folder, pred_folder) diff --git a/tools/end2end/readme.md b/tools/end2end/readme.md new file mode 100644 index 00000000..257a6f38 --- /dev/null +++ b/tools/end2end/readme.md @@ -0,0 +1,69 @@ + +# 简介 + +`tools/end2end`目录下存放了文本检测+文本识别pipeline串联预测的指标评测代码以及可视化工具。本节介绍文本检测+文本识别的端对端指标评估方式。 + + +## 端对端评测步骤 + +**步骤一:** + +运行`tools/infer/predict_system.py`,得到保存的结果: + +``` +python3 tools/infer/predict_system.py --det_model_dir=./ch_PP-OCRv2_det_infer/ --rec_model_dir=./ch_PP-OCRv2_rec_infer/ --image_dir=./datasets/img_dir/ --draw_img_save_dir=./ch_PP-OCRv2_results/ --is_visualize=True +``` + +文本检测识别可视化图默认保存在`./ch_PP-OCRv2_results/`目录下,预测结果默认保存在`./ch_PP-OCRv2_results/results.txt`中,格式如下: +``` +all-sum-510/00224225.jpg [{"transcription": "超赞", "points": [[8.0, 48.0], [157.0, 44.0], [159.0, 115.0], [10.0, 119.0]], "score": "0.99396634"}, {"transcription": "中", "points": [[202.0, 152.0], [230.0, 152.0], [230.0, 163.0], [202.0, 163.0]], "score": "0.09310734"}, {"transcription": "58.0m", "points": [[196.0, 192.0], [444.0, 192.0], [444.0, 240.0], [196.0, 240.0]], "score": "0.44041982"}, {"transcription": "汽配", "points": [[55.0, 263.0], [95.0, 263.0], [95.0, 281.0], [55.0, 281.0]], "score": "0.9986651"}, {"transcription": "成总店", "points": [[120.0, 262.0], [176.0, 262.0], [176.0, 283.0], [120.0, 283.0]], "score": "0.9929402"}, {"transcription": "K", "points": [[237.0, 286.0], [311.0, 286.0], [311.0, 345.0], [237.0, 345.0]], "score": "0.6074794"}, {"transcription": "88:-8", "points": [[203.0, 405.0], [477.0, 414.0], [475.0, 459.0], [201.0, 450.0]], "score": "0.7106863"}] +``` + + +**步骤二:** + +将步骤一保存的数据转换为端对端评测需要的数据格式: +修改 `tools/convert_ppocr_label.py`中的代码,convert_label函数中设置输入标签路径,Mode,保存标签路径等,对预测数据的GTlabel和预测结果的label格式进行转换。 + +``` +ppocr_label_gt = "gt_label.txt" +convert_label(ppocr_label_gt, "gt", "./save_gt_label/") + +ppocr_label_gt = "./infer_results/ch_PPOCRV2_infer.txt" +convert_label(ppocr_label_gt_en, "pred", "./save_PPOCRV2_infer/") +``` + +运行`convert_ppocr_label.py`: +``` +python3 tools/convert_ppocr_label.py +``` + +得到如下结果: +``` +├── ./save_gt_label/ +├── ./save_PPOCRV2_infer/ +``` + +**步骤三:** + +执行端对端评测,运行`tools/eval_end2end.py`计算端对端指标,运行方式如下: + +``` +python3 tools/eval_end2end.py "gt_label_dir" "predict_label_dir" +``` + +比如: + +``` +python3 tools/eval_end2end.py ./save_gt_label/ ./save_PPOCRV2_infer/ +``` +将得到如下结果,fmeasure为主要关注的指标: +``` +hit, dt_count, gt_count 1557 2693 3283 +character_acc: 61.77% +avg_edit_dist_field: 3.08 +avg_edit_dist_img: 51.82 +precision: 57.82% +recall: 47.43% +fmeasure: 52.11% +``` diff --git a/tools/infer/predict_system.py b/tools/infer/predict_system.py index 8d674809..e9aff6d2 100755 --- a/tools/infer/predict_system.py +++ b/tools/infer/predict_system.py @@ -27,6 +27,7 @@ import numpy as np import time import logging from PIL import Image +import json import tools.infer.utility as utility import tools.infer.predict_rec as predict_rec import tools.infer.predict_det as predict_det @@ -121,11 +122,31 @@ def sorted_boxes(dt_boxes): return _boxes +def save_results_to_txt(results, path): + if os.path.isdir(path): + if not os.path.exists(path): + os.makedirs(path) + with open(os.path.join(path, "results.txt"), 'w') as f: + f.writelines(results) + f.close() + logger.info("The results will be saved in {}".format( + os.path.join(path, "results.txt"))) + else: + draw_img_save = os.path.dirname(path) + if not os.path.exists(draw_img_save): + os.makedirs(draw_img_save) + + with open(path, 'w') as f: + f.writelines(results) + f.close() + logger.info("The results will be saved in {}".format(path)) + + def main(args): image_file_list = get_image_file_list(args.image_dir) image_file_list = image_file_list[args.process_id::args.total_process_num] text_sys = TextSystem(args) - is_visualize = True + is_visualize = args.is_visualize font_path = args.vis_font_path drop_score = args.drop_score @@ -139,6 +160,7 @@ def main(args): cpu_mem, gpu_mem, gpu_util = 0, 0, 0 _st = time.time() count = 0 + save_res = [] for idx, image_file in enumerate(image_file_list): img, flag = check_and_read_gif(image_file) @@ -152,6 +174,21 @@ def main(args): elapse = time.time() - starttime total_time += elapse + # save results + preds = [] + dt_num = len(dt_boxes) + for dno in range(dt_num): + text, score = rec_res[dno] + if score >= drop_score: + preds.append({ + "transcription": text, + "points": np.array(dt_boxes[dno]).tolist() + }) + text_str = "%s, %.3f" % (text, score) + save_res.append(image_file + '\t' + json.dumps( + preds, ensure_ascii=False) + '\n') + + # print predicted results logger.debug( str(idx) + " Predict time of %s: %.3fs" % (image_file, elapse)) for text, score in rec_res: @@ -180,6 +217,9 @@ def main(args): logger.debug("The visualized image saved in {}".format( os.path.join(draw_img_save_dir, os.path.basename(image_file)))) + # The predicted results will be saved in os.path.join(os.draw_img_save_dir, "results.txt") + save_results_to_txt(save_res, args.draw_img_save_dir) + logger.info("The predict total time is {}".format(time.time() - _st)) if args.benchmark: text_sys.text_detector.autolog.report() diff --git a/tools/infer/utility.py b/tools/infer/utility.py index 33ed6212..7b7b81e3 100644 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -114,6 +114,7 @@ def init_args(): # parser.add_argument( "--draw_img_save_dir", type=str, default="./inference_results") + parser.add_argument("--is_visualize", type=str2bool, default=True) parser.add_argument("--save_crop_res", type=str2bool, default=False) parser.add_argument("--crop_res_save_dir", type=str, default="./output") -- GitLab